From 3363ee158b7d5da4798863dad345157e7b58f16c Mon Sep 17 00:00:00 2001 From: aarnphm-ec2-dev <29749331+aarnphm@users.noreply.github.com> Date: Wed, 16 Aug 2023 10:44:23 +0000 Subject: [PATCH] fix(container): set correct PyTorch version not to override cuda wheels Signed-off-by: aarnphm-ec2-dev <29749331+aarnphm@users.noreply.github.com> --- .../src/openllm/bundle/oci/Dockerfile | 32 ++++++++++--------- 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/openllm-python/src/openllm/bundle/oci/Dockerfile b/openllm-python/src/openllm/bundle/oci/Dockerfile index 1e6b47ff..6f0ef484 100644 --- a/openllm-python/src/openllm/bundle/oci/Dockerfile +++ b/openllm-python/src/openllm/bundle/oci/Dockerfile @@ -2,7 +2,7 @@ # Adapted from: https://github.com/pytorch/pytorch/blob/master/Dockerfile FROM debian:bullseye-slim as pytorch-install -ARG PYTORCH_VERSION=2.0.0 +ARG PYTORCH_VERSION=2.0.1 ARG PYTHON_VERSION=3.9 ARG CUDA_VERSION=11.8 ARG MAMBA_VERSION=23.1.0-1 @@ -61,18 +61,18 @@ RUN /opt/conda/bin/conda install -c "nvidia/label/cuda-11.8.0" cuda==11.8 && \ /opt/conda/bin/conda clean -ya # NOTE: Build vllm CUDA kernels -# FROM kernel-builder as vllm-builder -# -# ENV COMMIT_HASH d1744376ae9fdbfa6a2dc763e1c67309e138fa3d -# ARG COMMIT_HASH=${COMMIT_HASH} -# -# WORKDIR /usr/src -# -# RUN <=2.0.1" xformers -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html ".[opt,mpt,fine-tune,llama,chatglm]" -v --no-cache-dir +RUN --mount=type=cache,target=/root/.cache/pip \ + pip install --extra-index-url "https://download.pytorch.org/whl/cu118" -v --no-cache-dir \ + "ray==2.6.0" "einops" "torch>=2.0.1+cu118" xformers "jax[cuda11_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html ".[opt,mpt,fine-tune,llama,chatglm]" FROM base-container