@yhtang
Contributor

@yhtang yhtang commented Apr 24, 2025

Currently, the JAX Python source installed in the container is created by build-jax.sh (via python build.py build --editable ... --wheels=jax,jaxlib,jax-cuda-plugin,jax-cuda-pjrt) and placed under /opt/jaxlibs/jax:

$ docker run --gpus all ghcr.io/nvidia/jax:jax python -c 'import jax; print(jax.__path__)'
['/opt/jaxlibs/jax/jax']

As a result, two copies of the JAX source exist in the container, at /opt/jax and /opt/jaxlibs/jax, which can cause confusion. This also defeats the purpose of an editable build: mounting a JAX working copy from outside the container onto /opt/jax does not replace the installed JAX package, which still resolves to /opt/jaxlibs/jax.

Given that the editable wheel at /opt/jaxlibs/jax is practically identical to the original source (minus a number of build-related files), we should skip creating this extra copy and specify /opt/jax directly as the installation target in the Dockerfile and build script. With this change, we will have

>>> import jax; print(jax.__path__)
['/opt/jax/jax']

which is the desired usage pattern.
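To check which copy of JAX a container actually picks up, without importing it, one can ask the import system where the package would be loaded from. Below is a minimal sketch built on importlib.util.find_spec; the helper names package_dirs and resolves_under are hypothetical and not part of this repo's scripts. For a regular package it should report the same directory as jax.__path__ above.

```python
import importlib.util
import os


def package_dirs(name: str) -> list[str]:
    """Return the directories package `name` would be imported from, without importing it.

    Returns [] if the module is not found or is not a package (e.g. a builtin module).
    """
    spec = importlib.util.find_spec(name)
    if spec is None or spec.submodule_search_locations is None:
        return []
    # Resolve symlinks so bind mounts / links compare by their real location.
    return [os.path.realpath(p) for p in spec.submodule_search_locations]


def resolves_under(name: str, prefix: str) -> bool:
    """True if every import location of `name` lies inside the directory `prefix`."""
    root = os.path.realpath(prefix)
    dirs = package_dirs(name)
    # commonpath compares whole path components, so /opt/jaxlibs/jax/jax is NOT
    # treated as being under /opt/jax (a plain startswith() check would get this wrong).
    return bool(dirs) and all(os.path.commonpath([root, d]) == root for d in dirs)
```

Inside the container, resolves_under("jax", "/opt/jax") would be expected to return True after this change, and False before it, when jax resolved to /opt/jaxlibs/jax/jax.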

@yhtang yhtang requested review from Steboss, gpupuck and olupton April 24, 2025 17:49
@gpupuck
Contributor

gpupuck commented Apr 24, 2025

You'll need to change test-jax.sh as well.

@yhtang
Contributor Author

yhtang commented Apr 24, 2025

test-jax.sh

done 😁

@gpupuck gpupuck force-pushed the fix-editable-build branch from c0d8fdf to 5d77674 April 25, 2025 03:10
Contributor

@gpupuck gpupuck left a comment

pip --disable-pip-version-check install .... -e ${BUILD_PATH_JAXLIB}/jax   <------
...
## after installation (example)
# jax                     0.5.4.dev20250325    /opt/jaxlibs/jax   <--------
# jax-cuda12-pjrt         0.5.4.dev20250325    /opt/jaxlibs/jax_cuda12_pjrt

The install dir and the resulting output need to be updated too.
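Besides jax.__path__, pip records the source directory of an editable install in the distribution's direct_url.json metadata file (PEP 610), which can serve as a check that the install dir was updated. A minimal sketch, assuming a pip recent enough to write that file; the helper names editable_location and installed_editable_location are hypothetical.

```python
import json
from importlib import metadata
from typing import Optional
from urllib.parse import urlparse
from urllib.request import url2pathname


def editable_location(direct_url_text: Optional[str]) -> Optional[str]:
    """Extract the project directory from PEP 610 direct_url.json text.

    Returns None if there is no metadata, the install is not editable,
    or the URL is not a local file: URL.
    """
    if not direct_url_text:
        return None
    data = json.loads(direct_url_text)
    if not data.get("dir_info", {}).get("editable", False):
        return None
    url = urlparse(data["url"])
    if url.scheme != "file":
        return None
    return url2pathname(url.path)


def installed_editable_location(dist_name: str) -> Optional[str]:
    """Look up where pip recorded an editable install of `dist_name`, if anywhere."""
    try:
        dist = metadata.distribution(dist_name)
    except metadata.PackageNotFoundError:
        return None
    # read_text returns None when the metadata file is absent.
    return editable_location(dist.read_text("direct_url.json"))
```

After the pip -e target is switched to /opt/jax, installed_editable_location("jax") would be expected to return "/opt/jax" rather than a path under /opt/jaxlibs.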

@yhtang
Contributor Author

yhtang commented Apr 25, 2025

Fixed

@yhtang yhtang merged commit fd8ab00 into main Apr 25, 2025
100 of 105 checks passed
@yhtang yhtang deleted the fix-editable-build branch April 25, 2025 17:11