jax
User documentation
The JAX container is developed by AMD specifically for LUMI and contains the necessary parts to run JAX on LUMI, including the plugin needed for RCCL when doing distributed AI, and a suitable version of ROCm for the version of JAX.
Note that JAX is still very much in development. Moreover, we sometimes have to use a newer version of ROCm than the drivers on LUMI support, so there is no guarantee that this container will work for you (even though it passed some of our tests), and there might be problems that cannot be fixed by the support team. This is software for users with a development spirit, not for users who expect something that simply and always works.
Use via EasyBuild-generated modules
The EasyBuild installation with the EasyConfigs mentioned below will do four things:
- It will copy the container to your own file space. We realise containers can be big, but it ensures that you have complete control over when a container is removed again.
  We will remove a container from the system when it is not sufficiently functional anymore, but the container may still work for you. E.g., after an upgrade of the network drivers on LUMI, the RCCL plugin for the LUMI Slingshot interconnect may be broken, but if you run on only one node, JAX may still work for you.
  If you prefer to use the centrally provided container, you can remove your copy after loading the module with
  rm $SIF
  followed by reloading the module. This is however at your own risk.
- It will create a module file. When loading the module, a number of environment variables will be set to help you use the module and to make it easy to swap the module with a different version in your job scripts.
  - SIF and SIFJAX both contain the name and full path of the Singularity container file.
  - SINGULARITY_BIND will mount all necessary directories from the system, including everything that is needed to access the project, scratch and flash file systems.
  - RUNSCRIPTS and RUNSCRIPTSJAX contain the full path of the directory containing some sample run scripts that can be used to run software in the container, or as inspiration for your own variants.
- It creates the $RUNSCRIPTS directory with scripts to be run in the container:
  - conda-python-simple: This initialises Python in the container and then calls Python with the arguments of conda-python-simple. It can be used, e.g., to run commands through Python that utilise a single task but all GPUs.
- It creates a bin directory with scripts to be run outside of the container:
  - start-shell: Serves a double purpose:
    - Without further arguments, it will start a shell in the container with the Conda environment used to build the container activated.
    - With arguments it simply runs a shell in the container, but the Conda environment will not be activated.
  The bin directory is not mounted in the container, but if you do mount it, the scripts will recognise this and either work or print a message that they cannot be used in that environment.
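After loading the module, you can quickly check that the variables described above are set. The sketch below is a hypothetical helper (show_jax_module_env is our own name, not something the module provides); it only reads the environment, so it should be harmless to run on a login node.

```bash
#!/bin/bash
# Hypothetical helper: print the environment variables that the jax module
# is documented to set, and fail if any of them is missing.
show_jax_module_env() {
    local var value
    for var in SIF SIFJAX SINGULARITY_BIND RUNSCRIPTS RUNSCRIPTSJAX; do
        value="${!var}"   # indirect expansion: value of the variable named in $var
        if [ -z "$value" ]; then
            echo "$var is not set; is the jax module loaded?" >&2
            return 1
        fi
        printf '%s=%s\n' "$var" "$value"
    done
}
```

For example, after module load LUMI jax/0.4.28-rocm-6.2.0-python-3.12-singularity-20241007, running show_jax_module_env followed by ls "$RUNSCRIPTS" shows the container path and the run scripts that come with it.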
The container uses a Miniconda environment in which Python and its packages are installed. That environment needs to be activated in the container when running, which can be done with the command that is available in the container as the environment variable WITH_CONDA (for this container it is source /opt/miniconda3/bin/activate jax).
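The activation command can also be used non-interactively, in a single singularity exec call. Below is a minimal sketch of a hypothetical helper, jax_conda_exec, assuming the jax module is loaded so that SIF and SINGULARITY_BIND are set. It lets bash inside the container expand $WITH_CONDA and then runs the requested command in the activated environment.

```bash
#!/bin/bash
# Hypothetical helper: run one command in the container with the Conda
# environment activated. Assumes the jax module is loaded, so SIF is set and
# SINGULARITY_BIND provides the system bindings.
jax_conda_exec() {
    if [ -z "${SIF:-}" ]; then
        echo "SIF is not set; load the jax module first" >&2
        return 1
    fi
    # The single quotes keep the host shell from expanding $WITH_CONDA; bash in
    # the container expands it to the activation command and then execs the
    # command passed as "$@" (the extra "bash" argument becomes the inner $0).
    singularity exec "$SIF" bash -c '$WITH_CONDA && exec "$@"' bash "$@"
}
```

Usage, for instance: jax_conda_exec pip list, or jax_conda_exec python -c 'import jax; print(jax.__version__)'.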
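Example of the use of WITH_CONDA: check the Python packages in the container in an interactive session:
module load LUMI jax/0.4.28-rocm-6.2.0-python-3.12-singularity-20241007
singularity shell $SIF
which takes you into the container, and then in the container, at the Singularity> prompt:
$WITH_CONDA
pip list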
An example of the use of start-shell that even works on the login nodes is:
module load LUMI jax/0.4.28-rocm-6.2.0-python-3.12-singularity-20241007
start-shell -c '/runscripts/conda-python-simple -c "import numpy ; import scipy ; import jax ; print( f'"'JAX {jax.__version__}, NumPy {numpy.__version__}, SciPy {scipy.__version__}.'"' )"'
The container (when used with the SINGULARITY_BIND setting of the module) also provides one or more wrapper scripts to start Python from the Conda environment in the container. After loading the module, those scripts are also available outside the container for inspection in the $RUNSCRIPTS directory, and you can use them as a source of inspiration to develop a script that more directly executes your commands or does additional initialisations.
Example (in an interactive session):
salloc -N1 -pstandard-g -t 10:00
module load LUMI jax/0.4.28-rocm-6.2.0-python-3.12-singularity-20241007
srun -N1 -n1 --gpus 8 singularity exec $SIF /runscripts/conda-python-simple \
-c 'import jax; print("I have these devices:", jax.devices("gpu"))'
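The interactive example above can also be run as a batch job. The sketch below is a hypothetical helper, make_jax_job, that prints an equivalent batch script with your project account filled in, so you can review it before submitting it. The account and the #SBATCH values are assumptions to adapt; the job body repeats the commands of the interactive example.

```bash
#!/bin/bash
# Hypothetical helper: print a batch script equivalent to the interactive
# example. ACCOUNT is your LUMI project account (placeholder in the usage line).
make_jax_job() {
    local account="$1"
    if [ -z "$account" ]; then
        echo "usage: make_jax_job PROJECT_ACCOUNT" >&2
        return 1
    fi
    # Unquoted EOF: $account is substituted now, while \$SIF stays literal and
    # is expanded only when the job runs (after the module has set it).
    cat <<EOF
#!/bin/bash
#SBATCH --account=$account
#SBATCH --partition=standard-g
#SBATCH --nodes=1
#SBATCH --gpus-per-node=8
#SBATCH --time=00:10:00

module load LUMI jax/0.4.28-rocm-6.2.0-python-3.12-singularity-20241007
srun -N1 -n1 --gpus 8 singularity exec \$SIF /runscripts/conda-python-simple \\
    -c 'import jax; print("I have these devices:", jax.devices("gpu"))'
EOF
}
```

Usage, with a hypothetical project number: make_jax_job project_465000000 > jax-devices.sh && sbatch jax-devices.sh.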
Installation
To install the container with EasyBuild, follow the instructions in the EasyBuild section of the LUMI documentation, section "Software", and use the dummy partition container, e.g.:
module load LUMI partition/container EasyBuild-user
eb jax-0.4.28-rocm-6.2.0-python-3.12-singularity-20241007.eb
To use the container after installation, neither the EasyBuild-user module nor the container partition is needed. The module will be available in all versions of the LUMI stack and in the CrayEnv stack (provided the environment variable EBU_USER_PREFIX points to the right location).
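The installation steps can be wrapped in a small function. The sketch below uses a hypothetical helper name, install_jax_container; it reports where EasyBuild-user will install (EBU_USER_PREFIX, with $HOME/EasyBuild printed as a fallback, which is our assumption about the default) before loading the modules and running eb.

```bash
#!/bin/bash
# Hypothetical helper: install a jax container EasyConfig with EasyBuild-user.
install_jax_container() {
    local easyconfig="$1"
    if [ -z "$easyconfig" ]; then
        echo "usage: install_jax_container EASYCONFIG.eb" >&2
        return 1
    fi
    # EasyBuild-user installs under EBU_USER_PREFIX; the $HOME/EasyBuild
    # fallback shown here is an assumption about the default location.
    echo "Installing into ${EBU_USER_PREFIX:-$HOME/EasyBuild}"
    module load LUMI partition/container EasyBuild-user || return 1
    eb "$easyconfig"
}
```

For example, with a hypothetical project directory: export EBU_USER_PREFIX=/project/project_465000000/EasyBuild, then install_jax_container jax-0.4.28-rocm-6.2.0-python-3.12-singularity-20241007.eb. Anyone who later wants to use the module needs EBU_USER_PREFIX pointing to the same location.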
Direct access (use without the container module)
The jax containers are available in the following subdirectories of /appl/local/containers:
- /appl/local/containers/sif-images: Symbolic link to the latest version of the container for each ROCm version provided. Those links can change without notice!
- /appl/local/containers/tested-containers: Tested containers provided as a Singularity .sif file and a Docker-generated tarball. Containers in this directory are removed quickly when a new version becomes available.
- /appl/local/containers/easybuild-sif-images: Singularity .sif images used with the EasyConfigs that we provide. They tend to be available for a longer time than in the other two subdirectories.
If you depend on a particular version of a container, we recommend that you copy the container to your own file space (e.g., in /project) as there is no guarantee the specific version will remain available centrally on the system for as long as you want.
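Copying an image into your own file space can be scripted as below. copy_sif is a hypothetical helper, and the image name and destination in the usage line are placeholders, not actual file names on the system. It uses cp -L so that, when the source is one of the symbolic links in sif-images, the image itself is copied rather than the link.

```bash
#!/bin/bash
# Hypothetical helper: copy a container image into DESTDIR and print the path
# of the copy, which you can then use instead of the central image.
copy_sif() {
    local src="$1" destdir="$2"
    if [ ! -f "$src" ] || [ -z "$destdir" ]; then
        echo "usage: copy_sif IMAGE.sif DESTDIR" >&2
        return 1
    fi
    mkdir -p "$destdir" || return 1
    # -L: always follow symbolic links in the source
    cp -L "$src" "$destdir/" || return 1
    printf '%s\n' "$destdir/$(basename "$src")"
}
```

Usage, with placeholder names: copy_sif /appl/local/containers/easybuild-sif-images/<image>.sif /project/project_465000000/containers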
When using the containers without the modules, you will have to take care of the bindings as some system files are needed for, e.g., MPI. The recommended minimal bindings are:
and the bindings you need to access the files you want to use from /scratch, /flash and/or /project.
You can get access to your files on LUMI in the regular location by also using the bindings
Note that the list of recommended bindings may change after a system update or between different containers. We do try to keep the EasyBuild recipes for the modules up to date, though, to reflect those changes.
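Without the module, the bindings have to be passed to Singularity yourself. A minimal sketch of a hypothetical wrapper, jax_exec_nomodule, that passes a comma-separated bind list with -B. It does not know the recommended minimal bindings, so you have to supply them together with the /scratch, /flash or /project directories you need.

```bash
#!/bin/bash
# Hypothetical wrapper: run COMMAND in a container image without the module,
# with an explicit comma-separated bind list (src[:dest[:opts]] entries).
jax_exec_nomodule() {
    if [ "$#" -lt 3 ]; then
        echo "usage: jax_exec_nomodule IMAGE.sif BIND_LIST COMMAND [ARGS...]" >&2
        return 1
    fi
    local sif="$1" binds="$2"
    shift 2
    singularity exec -B "$binds" "$sif" "$@"
}
```

Usage, where <recommended bindings> stands for the recommended minimal bindings and the image and project paths are placeholders: jax_exec_nomodule "$HOME/containers/jax.sif" "<recommended bindings>,/scratch/project_465000000" bash -c '$WITH_CONDA && python -c "import jax; print(jax.devices())"'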
Singularity containers with modules for binding and extras
Install with the EasyBuild-user module in partition/container:
To access module help after installation, use module spider jax/<version>.
EasyConfig:
- Contains JAX 0.4.13 with NumPy 1.26.0 and SciPy 1.11.4.
- Contains JAX 0.4.13 with NumPy 1.26.3 and SciPy 1.12.0.
- Contains JAX 0.4.28 with NumPy 1.26.4, SciPy 1.14.1 and pandas 2.2.2.
Archived EasyConfigs
The EasyConfigs below are additional EasyConfigs that are not directly available on the system for installation. Users are advised to use the newer ones; these archived ones are unsupported. They are still provided as a source of information should you need it, e.g., to understand the configuration that was used for earlier work on the system.
- Archived EasyConfigs from LUMI-EasyBuild-contrib (previously user-installable software)
  - EasyConfig jax-0.3.20-cpeCray-22.08-rocm5.2.eb, with module jax/0.3.20-cpeCray-22.08-rocm5.2
  - EasyConfig jax-0.3.20-cpeCray-22.12-rocm5.2.eb, with module jax/0.3.20-cpeCray-22.12-rocm5.2
  - EasyConfig jax-0.3.20-cpeCray-23.03-rocm5.2.eb, with module jax/0.3.20-cpeCray-23.03-rocm5.2
  - EasyConfig jax-0.3.20-cpeGNU-22.08-rocm5.2.eb, with module jax/0.3.20-cpeGNU-22.08-rocm5.2
  - EasyConfig jax-0.3.20-cpeGNU-22.12-rocm5.2.eb, with module jax/0.3.20-cpeGNU-22.12-rocm5.2
  - EasyConfig jax-0.4.1-cpeCray-22.08-rocm5.3.eb, with module jax/0.4.1-cpeCray-22.08-rocm5.3
  - EasyConfig jax-0.4.1-cpeGNU-22.08-rocm5.3.eb, with module jax/0.4.1-cpeGNU-22.08-rocm5.3