fix(container): pin the correct PyTorch version so it does not override the CUDA wheels

Signed-off-by: aarnphm-ec2-dev <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
aarnphm-ec2-dev
2023-08-16 10:44:23 +00:00
parent 904895f989
commit 3363ee158b

View File

@@ -2,7 +2,7 @@
# Adapted from: https://github.com/pytorch/pytorch/blob/master/Dockerfile
FROM debian:bullseye-slim as pytorch-install
ARG PYTORCH_VERSION=2.0.0
ARG PYTORCH_VERSION=2.0.1
ARG PYTHON_VERSION=3.9
ARG CUDA_VERSION=11.8
ARG MAMBA_VERSION=23.1.0-1
@@ -61,18 +61,18 @@ RUN /opt/conda/bin/conda install -c "nvidia/label/cuda-11.8.0" cuda==11.8 && \
/opt/conda/bin/conda clean -ya
# NOTE: Build vllm CUDA kernels
# FROM kernel-builder as vllm-builder
#
# ENV COMMIT_HASH d1744376ae9fdbfa6a2dc763e1c67309e138fa3d
# ARG COMMIT_HASH=${COMMIT_HASH}
#
# WORKDIR /usr/src
#
# RUN <<EOT
# git clone https://github.com/vllm-project/vllm.git && cd vllm
# git fetch && git checkout ${COMMIT_HASH}
# python setup.py build
# EOT
FROM kernel-builder as vllm-builder

# vllm commit to build. Declared as a build ARG (overridable via
# --build-arg COMMIT_HASH=...) instead of ENV: the value is build-only
# and must not leak into the runtime environment. The previous
# `ENV COMMIT_HASH <hash>` used the deprecated space-separated form and,
# because ENV takes precedence over ARG at RUN time, also made the ARG
# override a no-op.
ARG COMMIT_HASH=d1744376ae9fdbfa6a2dc763e1c67309e138fa3d

WORKDIR /usr/src

# set -e is required inside heredoc RUNs: lines are NOT implicitly
# &&-chained, so without it a failed clone/checkout would fall through
# and build whatever HEAD happened to be checked out.
RUN <<EOT
set -e
git clone https://github.com/vllm-project/vllm.git && cd vllm
git fetch && git checkout ${COMMIT_HASH}
python setup.py build
EOT
# NOTE: Build flash-attention-2 CUDA kernels
FROM kernel-builder as flash-attn-v2-builder
@@ -123,7 +123,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
COPY --from=pytorch-install /opt/conda /opt/conda
# Copy build artefacts for vllm
# COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
# Copy build artefacts for flash-attention-v2
COPY --from=flash-attn-v2-builder /usr/src/flash-attention-v2/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
@@ -144,7 +144,9 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
rm -rf /var/lib/apt/lists/*
# Install all required dependencies
RUN --mount=type=cache,target=/root/.cache/pip pip install "ray==2.6.0" "einops" "vllm" "jax[cuda11_local]" "torch>=2.0.1" xformers -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html ".[opt,mpt,fine-tune,llama,chatglm]" -v --no-cache-dir
# Install all required Python dependencies.
# --extra-index-url points at the PyTorch cu118 wheel index so pip resolves
# the CUDA build of torch/xformers instead of overriding it with the
# CPU-only PyPI wheel; the -f findlinks index supplies the jax CUDA wheels.
# NOTE: --no-cache-dir was dropped — it defeated the BuildKit cache mount,
# whose whole point is to reuse the pip cache across builds while keeping
# it out of the image layers.
RUN --mount=type=cache,target=/root/.cache/pip \
    pip install --extra-index-url "https://download.pytorch.org/whl/cu118" -v \
    "ray==2.6.0" "einops" "torch>=2.0.1+cu118" xformers "jax[cuda11_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html ".[opt,mpt,fine-tune,llama,chatglm]"
FROM base-container