Description
While attempting to install jax-neuronx and its dependencies on a trn1.2xlarge running Amazon Linux 2023, I encountered deviations from the official documentation and was ultimately unable to install jax==0.4.31.
I followed the instructions at:
https://awsdocs-neuron.readthedocs-hosted.com/en/latest/frameworks/jax/setup/jax-setup.html
[neuron]
name=Neuron YUM Repository
baseurl=https://yum.repos.neuron.amazonaws.com
enabled=1
metadata_expire=0
EOF
sudo rpm --import https://yum.repos.neuron.amazonaws.com/GPG-PUB-KEY-AMAZON-AWS-NEURON.PUB
sudo yum update -y
sudo yum install kernel-devel-$(uname -r) kernel-headers-$(uname -r) -y
sudo yum install git -y
sudo yum install aws-neuronx-dkms-2.* -y
sudo yum install aws-neuronx-collectives-2.* -y
sudo yum install aws-neuronx-runtime-lib-2.* -y
sudo yum install aws-neuronx-tools-2.* -y
export PATH=/opt/aws/neuron/bin:$PATH
I confirmed the accelerator is visible:
neuron-ls
instance-type: trn1.2xlarge
instance-id: i-03a73ac140834eb76
+--------+--------+--------+--------------+
| NEURON | NEURON | NEURON | PCI |
| DEVICE | CORES | MEMORY | BDF |
+--------+--------+--------+--------------+
| 0 | 2 | 32 GB | 0000:00:1e.0 |
+--------+--------+--------+--------------+
The instructions say to install through python3 -m pip, but pip was not present:
python3 -m pip
/usr/bin/python3: No module named pip
So I first had to install pip:
sudo yum install python3-pip -y
When I tried the recommended version combination, the install failed:
python3 -m pip install jax==0.4.31 jaxlib==0.4.31 jax-neuronx libneuronxla neuronx-cc==2.* --extra-index-url=https://pip.repos.neuron.amazonaws.com
Defaulting to user installation because normal site-packages is not writeable
Looking in indexes: https://pypi.org/simple, https://pip.repos.neuron.amazonaws.com
ERROR: Could not find a version that satisfies the requirement jax==0.4.31 (from versions: 0.0, 0.1, 0.1.1, 0.1.2, 0.1.3, 0.1.4, 0.1.5, 0.1.6, 0.1.7, 0.1.8, 0.1.9, 0.1.10, 0.1.11, 0.1.12, 0.1.13, 0.1.14, 0.1.15, 0.1.16, 0.1.18, 0.1.19, 0.1.20, 0.1.21, 0.1.22, 0.1.23, 0.1.24, 0.1.25, 0.1.26, 0.1.27, 0.1.28, 0.1.29, 0.1.30, 0.1.31, 0.1.32, 0.1.33, 0.1.34, 0.1.35, 0.1.36, 0.1.37, 0.1.38, 0.1.39, 0.1.40, 0.1.41, 0.1.42, 0.1.43, 0.1.44, 0.1.45, 0.1.46, 0.1.47, 0.1.48, 0.1.49, 0.1.50, 0.1.51, 0.1.52, 0.1.53, 0.1.54, 0.1.55, 0.1.56, 0.1.57, 0.1.58, 0.1.59, 0.1.60, 0.1.61, 0.1.62, 0.1.63, 0.1.64, 0.1.65, 0.1.66, 0.1.67, 0.1.68, 0.1.69, 0.1.70, 0.1.71, 0.1.72, 0.1.73, 0.1.74, 0.1.75, 0.1.76, 0.1.77, 0.2.0, 0.2.1, 0.2.2, 0.2.3, 0.2.4, 0.2.5, 0.2.6, 0.2.7, 0.2.8, 0.2.9, 0.2.10, 0.2.11, 0.2.12, 0.2.13, 0.2.14, 0.2.15, 0.2.16, 0.2.17, 0.2.18, 0.2.19, 0.2.20, 0.2.21, 0.2.22, 0.2.23, 0.2.24, 0.2.25, 0.2.26, 0.2.27, 0.2.28, 0.3.0, 0.3.1, 0.3.2, 0.3.3, 0.3.4, 0.3.5, 0.3.6, 0.3.7, 0.3.8, 0.3.9, 0.3.10, 0.3.11, 0.3.12, 0.3.13, 0.3.14, 0.3.15, 0.3.16, 0.3.17, 0.3.18, 0.3.19, 0.3.20, 0.3.21, 0.3.22, 0.3.23, 0.3.24, 0.3.25, 0.4.0, 0.4.1, 0.4.2, 0.4.3, 0.4.4, 0.4.5, 0.4.6, 0.4.7, 0.4.8, 0.4.9, 0.4.10, 0.4.11, 0.4.12, 0.4.13, 0.4.14, 0.4.15, 0.4.16, 0.4.17, 0.4.18, 0.4.19, 0.4.20, 0.4.21, 0.4.22, 0.4.23, 0.4.24, 0.4.25, 0.4.26, 0.4.27, 0.4.28, 0.4.29, 0.4.30)
ERROR: No matching distribution found for jax==0.4.31
The Python version on this instance:
python3 --version
Python 3.9.23
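A possible explanation for the version list stopping at 0.4.30: pip hides releases whose declared minimum Python is newer than the running interpreter. The sketch below checks the interpreter against a minimum before installing. It rests on two assumptions that are not confirmed anywhere in this report: that jax 0.4.31 requires Python 3.10 or newer (as I understand its release notes), and that GNU `sort -V` is available. The names `version_ge` and `min_python` are made up for illustration.

```shell
#!/usr/bin/env bash
# version_ge A B: succeeds when dotted version A is >= B.
# Relies on GNU sort's -V (version sort) option.
version_ge() {
    [ "$(printf '%s\n%s\n' "$1" "$2" | sort -V | head -n 1)" = "$2" ]
}

# Assumption: jax 0.4.31 declares a minimum Python of 3.10.
min_python="3.10"

# Ask the interpreter for its own version, e.g. "3.9.23".
py_version="$(python3 -c 'import platform; print(platform.python_version())' 2>/dev/null)"

if version_ge "$py_version" "$min_python"; then
    echo "Python $py_version meets the assumed minimum $min_python"
else
    echo "Python ${py_version:-not found} is below the assumed minimum $min_python; pip will not offer jax 0.4.31"
fi
```

If the check fails, installing a newer interpreter and running pip from a virtual environment built on it would be the thing to try next; I have not verified that on this instance.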