| # syntax=docker/dockerfile:1 |
| |
| # NOTE: Building this image requires Docker version >= 23.0. |
| # |
| # For reference: |
| # - https://docs.docker.com/build/dockerfile/frontend/#stable-channel |
| |
| # Build-wide defaults. ARGs declared before the first FROM are only visible to |
| # FROM lines; stages that need PYTHON_VERSION redeclare it themselves. |
| ARG BASE_IMAGE=ubuntu:22.04 |
| ARG PYTHON_VERSION=3.11 |
| |
| # Stage: common compiler toolchain and image libraries shared by later stages. |
| FROM ${BASE_IMAGE} AS dev-base |
| RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ |
| build-essential \ |
| ca-certificates \ |
| ccache \ |
| cmake \ |
| curl \ |
| git \ |
| libjpeg-dev \ |
| libpng-dev && \ |
| rm -rf /var/lib/apt/lists/* |
| # Refresh ccache's compiler symlinks. |
| # NOTE(review): the symlink directory is not added to PATH here; confirm the |
| # build actually routes compiles through ccache. |
| RUN /usr/sbin/update-ccache-symlinks |
| # ccache stores its cache in /opt/ccache (mounted as a BuildKit cache in the build stage). |
| RUN mkdir /opt/ccache && ccache --set-config=cache_dir=/opt/ccache |
| # Conda binaries (installed to /opt/conda in the conda stage) take precedence on PATH. |
| ENV PATH=/opt/conda/bin:$PATH |
| |
| # Stage: Miniforge-based conda environment plus the Python build requirements. |
| FROM dev-base AS conda |
| ARG PYTHON_VERSION=3.11 |
| # Automatically set by buildx |
| ARG TARGETPLATFORM |
| # Translate Docker's TARGETPLATFORM into the Miniforge installer arch; any |
| # platform other than linux/arm64 falls back to x86_64. |
| # NOTE(review): "latest" is unpinned, so rebuilds may pick up a different installer. |
| RUN case ${TARGETPLATFORM} in \ |
| "linux/arm64") MINICONDA_ARCH=aarch64 ;; \ |
| *) MINICONDA_ARCH=x86_64 ;; \ |
| esac && \ |
| curl -fsSL -v -o ~/miniconda.sh "https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-Linux-${MINICONDA_ARCH}.sh" |
| COPY requirements.txt . |
| # Manually invoke bash on miniconda script per https://github.com/conda/conda/issues/10431 |
| # --no-cache-dir keeps pip's wheel cache out of this layer. |
| RUN chmod +x ~/miniconda.sh && \ |
| bash ~/miniconda.sh -b -p /opt/conda && \ |
| rm ~/miniconda.sh && \ |
| /opt/conda/bin/conda install -y python=${PYTHON_VERSION} cmake conda-build pyyaml numpy ipython && \ |
| /opt/conda/bin/python -mpip install --no-cache-dir -r requirements.txt && \ |
| /opt/conda/bin/conda clean -ya |
| |
| # Stage: full source tree from the build context with git submodules checked out. |
| # `git submodule update` needs the repository metadata (.git) in the build context. |
| FROM dev-base AS submodule-update |
| WORKDIR /opt/pytorch |
| COPY . . |
| RUN git submodule update --init --recursive |
| |
| # Stage: compile and install PyTorch from source into the conda environment. |
| FROM conda AS build |
| ARG CMAKE_VARS |
| WORKDIR /opt/pytorch |
| # /opt/conda is already present because this stage is FROM conda; only the |
| # submodule-complete source tree needs to be brought in. |
| COPY --from=submodule-update /opt/pytorch /opt/pytorch |
| RUN make triton |
| # /opt/ccache (ccache's cache_dir) persists across builds via a BuildKit cache mount. |
| # NOTE(review): in `export eval ${CMAKE_VARS}`, "eval" is just another name passed |
| # to `export` (not the eval builtin). CMAKE_VARS is word-split, so values containing |
| # spaces or quotes are not supported. |
| RUN --mount=type=cache,target=/opt/ccache \ |
| export eval ${CMAKE_VARS} && \ |
| TORCH_CUDA_ARCH_LIST="7.0 7.2 7.5 8.0 8.6 8.7 8.9 9.0 9.0a" TORCH_NVCC_FLAGS="-Xfatbin -compress-all" \ |
| CMAKE_PREFIX_PATH="$(dirname $(which conda))/../" \ |
| python setup.py install |
| |
| # Stage: install prebuilt torch/torchvision/torchaudio wheels into the conda env. |
| FROM conda AS conda-installs |
| ARG PYTHON_VERSION=3.11 |
| ARG CUDA_PATH=cu121 |
| # NOTE(review): CUDA_CHANNEL is not referenced anywhere in this stage. |
| ARG CUDA_CHANNEL=nvidia |
| ARG INSTALL_CHANNEL=whl/nightly |
| RUN /opt/conda/bin/conda update -y -n base -c defaults conda |
| RUN /opt/conda/bin/conda install -y python=${PYTHON_VERSION} |
| |
| # Automatically set by buildx |
| ARG TARGETPLATFORM |
| |
| # INSTALL_CHANNEL: whl - release, whl/nightly - nightly, whl/test - test channels. |
| # linux/arm64 installs from PyPI with the CPU index as an extra source; every other |
| # platform uses only the CUDA_PATH-specific index under INSTALL_CHANNEL. |
| RUN case ${TARGETPLATFORM} in \ |
| "linux/arm64") pip install --extra-index-url https://download.pytorch.org/whl/cpu/ torch torchvision torchaudio ;; \ |
| *) pip install --index-url https://download.pytorch.org/${INSTALL_CHANNEL}/${CUDA_PATH#.}/ torch torchvision torchaudio ;; \ |
| esac && \ |
| /opt/conda/bin/conda clean -ya |
| RUN /opt/conda/bin/pip install torchelastic |
| # Fail the build if CUDA_VERSION is set but the installed torch lacks CUDA support. |
| # CUDA_VERSION is not declared as an ARG here; presumably it is inherited as an |
| # ENV from a CUDA base image -- TODO confirm. |
| RUN IS_CUDA=$(python -c 'import torch ; print(torch.cuda._is_compiled())'); \ |
| echo "Is torch compiled with cuda: ${IS_CUDA}"; \ |
| if [ "${IS_CUDA}" != "True" ] && [ -n "${CUDA_VERSION}" ]; then \ |
| echo "CUDA_VERSION=${CUDA_VERSION} is set but torch was not compiled with CUDA" >&2; \ |
| exit 1; \ |
| fi |
| |
| # Stage: runtime image containing only the conda env built in conda-installs. |
| FROM ${BASE_IMAGE} AS official |
| ARG PYTORCH_VERSION |
| ARG TRITON_VERSION |
| ARG TARGETPLATFORM |
| ARG CUDA_VERSION |
| LABEL com.nvidia.volumes.needed="nvidia_driver" |
| RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ |
| ca-certificates \ |
| libjpeg-dev \ |
| libpng-dev \ |
| && rm -rf /var/lib/apt/lists/* |
| COPY --from=conda-installs /opt/conda /opt/conda |
| # Install gcc only when TRITON_VERSION is set and the target is not linux/arm64 |
| # (presumably for Triton's runtime compilation). The apt lists were removed above, |
| # so refresh them first. Chaining with && makes a failed install fail the build |
| # instead of being masked by the trailing rm. |
| RUN if [ -n "${TRITON_VERSION}" ] && [ "${TARGETPLATFORM}" != "linux/arm64" ]; then \ |
| apt-get update && \ |
| DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends gcc && \ |
| rm -rf /var/lib/apt/lists/*; \ |
| fi |
| ENV PATH=/opt/conda/bin:$PATH |
| ENV NVIDIA_VISIBLE_DEVICES=all |
| ENV NVIDIA_DRIVER_CAPABILITIES=compute,utility |
| # Replaces (does not append to) any LD_LIBRARY_PATH inherited from the base image. |
| ENV LD_LIBRARY_PATH=/usr/local/nvidia/lib:/usr/local/nvidia/lib64 |
| ENV PATH=/usr/local/nvidia/bin:/usr/local/cuda/bin:$PATH |
| ENV PYTORCH_VERSION=${PYTORCH_VERSION} |
| WORKDIR /workspace |
| |
| # Stage: official runtime image with the source-built conda env from the build stage. |
| FROM official AS dev |
| # Should override the already installed version from the official-image stage. |
| # COPY merges into the existing /opt/conda: same-path files are overwritten, while |
| # files present only in the official stage remain. |
| COPY --from=build /opt/conda /opt/conda |