Jax
跳到导航
跳到搜索
安装
- jax和jaxlib版本要匹配(注意cuda支持)
- 我的CUDA版本是11.1 ,卡是A40,系统是Ubuntu,cudnn版本是805
- 安装最新版本 jax 0.2.26和jaxlib0.1.75后会在random函数报错“CustomCall failed: jaxlib/cuda_prng_kernels.cc:30: operation cudaGetLastError() failed: the provided PTX was compiled with an unsupported toolchain”
- 最后安装的 jax0.2.2 (pip install -v jax==0.2.2), jaxlib是0.1.72 [3]解决问题
- 好像是因为一定要11.1的驱动11的不行(或者cudnn的问题)