Building PyTorch from Source

Install conda

Python 3.8 or later is required.

mkdir -p ~/miniconda3
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda3/miniconda.sh
bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3
rm -rf ~/miniconda3/miniconda.sh

~/miniconda3/bin/conda init bash
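
Once conda is set up, the interpreter version requirement stated above can be checked from Python itself. The following is a minimal sketch; the helper name `meets_minimum` and the constant `MINIMUM_PYTHON` are made up for illustration and are not part of conda or PyTorch.

```python
import sys

# Minimum interpreter version stated in this guide.
MINIMUM_PYTHON = (3, 8)


def meets_minimum(version, minimum=MINIMUM_PYTHON):
    """Return True if a (major, minor, ...) version tuple is at least `minimum`."""
    return tuple(version[:2]) >= tuple(minimum)


# Report on the interpreter that is running this script.
if meets_minimum(sys.version_info):
    print("Python %d.%d is new enough" % sys.version_info[:2])
else:
    print("Python %d.%d is too old; 3.8 or later is required" % sys.version_info[:2])
```

Run it with the `python` from the conda environment that will be used for the build, since that is the interpreter the build will pick up.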

Download the source

Clone the source from the PyTorch GitHub repository, together with its git submodules.

git clone --recursive https://github.com/pytorch/pytorch
cd pytorch
# if you are updating an existing checkout
git submodule sync
git submodule update --init --recursive

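Cloning can fail when a remote is unreachable, especially while fetching the third-party libraries. Retry, or configure an HTTP proxy and an SSH proxy: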

HTTP proxy:

export http_proxy=host:port
export https_proxy=host:port

SSH proxy:

cat ~/.ssh/config

Host github.com
User git
ProxyCommand nc -v -x localhost:7890 %h %p

Install dependencies

Configuring pip to use the Tsinghua mirror is recommended; otherwise downloads can be very slow.

python -m pip install --upgrade pip
pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple

Install the PyTorch dependencies. The compiler must support C++17, which means gcc 9.4.0 or later.

conda install cmake ninja
# Run this command from the PyTorch source directory, after cloning the source as described above
pip install -r requirements.txt
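
The gcc requirement can be checked before starting a long build. The sketch below compares the version reported by `gcc -dumpfullversion -dumpversion` against 9.4.0. The helper names (`parse_version`, `gcc_is_supported`, `installed_gcc_version`) are hypothetical, and the script assumes the compiler used for the build is the `gcc` found on `PATH`.

```python
import re
import shutil
import subprocess

# Minimum gcc version stated in this guide.
MINIMUM_GCC = (9, 4, 0)


def parse_version(text):
    """Turn a dotted version string such as '9.4.0' into a tuple of ints."""
    match = re.match(r"\d+(?:\.\d+)*", text.strip())
    if match is None:
        raise ValueError("not a version string: %r" % text)
    return tuple(int(part) for part in match.group(0).split("."))


def gcc_is_supported(version_text, minimum=MINIMUM_GCC):
    """Return True if the given gcc version string is at least `minimum`."""
    return parse_version(version_text) >= minimum


def installed_gcc_version():
    """Return the version string reported by gcc, or None if it cannot be read."""
    if shutil.which("gcc") is None:
        return None
    # -dumpfullversion prints major.minor.patch on gcc 7 and later;
    # older releases ignore it and answer -dumpversion instead.
    result = subprocess.run(
        ["gcc", "-dumpfullversion", "-dumpversion"],
        capture_output=True,
        text=True,
    )
    if result.returncode != 0:
        return None
    return result.stdout.strip() or None


version = installed_gcc_version()
if version is None:
    print("could not determine the gcc version")
elif gcc_is_supported(version):
    print("gcc", version, "is new enough")
else:
    print("gcc", version, "is too old; 9.4.0 or later is required")
```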

Build

export CMAKE_PREFIX_PATH=${CONDA_PREFIX:-"$(dirname $(which conda))/../"}
python setup.py develop

Some useful build-time environment variables:

  • DEBUG=1 will enable debug builds (-g -O0)
  • REL_WITH_DEB_INFO=1 will enable debug symbols with optimizations (-g -O3)
  • USE_DISTRIBUTED=0 will disable distributed (c10d, gloo, mpi, etc.) build.
  • USE_MKLDNN=0 will disable using MKL-DNN.
  • USE_CUDA=0 will disable compiling CUDA (in case you are developing on something not CUDA related), to save compile time.
  • BUILD_TEST=0 will disable building C++ test binaries.
  • USE_FBGEMM=0 will disable using FBGEMM (quantized 8-bit server operators).
  • USE_NNPACK=0 will disable compiling with NNPACK.
  • USE_QNNPACK=0 will disable QNNPACK build (quantized 8-bit operators).
  • USE_XNNPACK=0 will disable compiling with XNNPACK.
  • USE_FLASH_ATTENTION=0 and USE_MEM_EFF_ATTENTION=0 will disable compiling flash attention and memory-efficient attention kernels, respectively.
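
These variables are read from the environment of the build process, so they have to be set before `python setup.py develop` starts. One way to combine several of them is sketched below. The helpers `make_build_env` and `run_build` are hypothetical, the chosen flags are only an example selection from the list above, and `~/pytorch` is a placeholder for the checkout location.

```python
import os
import subprocess
import sys


def make_build_env(flags, base_env=None):
    """Return a copy of the environment with the given build flags set."""
    env = dict(os.environ if base_env is None else base_env)
    # Environment values must be strings, so convert ints such as 0 and 1.
    env.update({name: str(value) for name, value in flags.items()})
    return env


def run_build(source_dir, flags):
    """Run `python setup.py develop` in source_dir with the flags applied."""
    return subprocess.run(
        [sys.executable, "setup.py", "develop"],
        cwd=source_dir,
        env=make_build_env(flags),
        check=True,
    )


# Example selection: a CPU-only debug build without the C++ test binaries.
cpu_debug_flags = {"DEBUG": 1, "USE_CUDA": 0, "BUILD_TEST": 0}

# Uncomment to start the build from a local checkout (the path is a placeholder):
# run_build(os.path.expanduser("~/pytorch"), cpu_debug_flags)
```

Building the environment as a copy keeps the flags scoped to the build process instead of leaking into the rest of the shell session.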

Verify

import torch

a = torch.randn(2, 3)
b = torch.randn(2, 3)
print(a + b)
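
The check above can be extended to report which build was actually imported. This is a minimal sketch: `check_torch` is a hypothetical helper, and it only uses `torch.__version__`, `torch.cuda.is_available()` and tensor addition.

```python
import importlib.util


def check_torch():
    """Return basic facts about the installed torch, or None if it is not importable."""
    if importlib.util.find_spec("torch") is None:
        return None
    import torch

    a = torch.randn(2, 3)
    b = torch.randn(2, 3)
    total = a + b
    return {
        "version": torch.__version__,
        "cuda_available": torch.cuda.is_available(),
        "sum_shape": tuple(total.shape),
    }


report = check_torch()
print("torch is not importable" if report is None else report)
```

With `USE_CUDA=0` at build time, `cuda_available` is expected to be `False`.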

References:

[1] PyTorch installation from source.

[2] PyTorch contributing doc.