Stuck on an issue?

Lightrun Answers was designed to reduce the constant googling that comes with debugging third-party libraries. It collects links to all the places you might be looking while hunting down a tough bug.

And, if you’re still stuck at the end, we’re happy to hop on a call to see how we can help out.

Issues (or typos?) when running JAX code with multiple GPUs

See original GitHub issue

Hi there. Got two issues when running the JAX code with multiple GPUs:

  1. https://github.com/deepmind/ferminet/blob/b46077f1a4687b3ac03b583d9726f59ab4d914d9/ferminet/train.py#L293-L297 hits a "too many values to unpack" error when num_devices is greater than 1. My understanding is that we should do
key, *subkeys = jax.random.split(key, num_devices+1)

instead (note the extra asterisk), in which case the explicit broadcast that follows is no longer necessary for the single-GPU case (see the sketch after this list).

  2. https://github.com/deepmind/ferminet/blob/b46077f1a4687b3ac03b583d9726f59ab4d914d9/ferminet/train.py#L372-L373 constants.pmap gives a tuple of an array instead of just an array when num_devices is greater than 1 (not sure why, probably just how JAX's API works). This causes the logging call to complain. It's easy to fix, though.
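
A minimal sketch of the first point (not the FermiNet code itself; num_devices is set to 2 purely for illustration), showing why the plain two-way unpack fails on multiple devices and why the starred unpack works for any device count:

import jax
import jax.numpy as jnp

num_devices = 2  # placeholder; in train.py this comes from the device count
key = jax.random.PRNGKey(0)

# jax.random.split(key, n) returns n stacked keys, so a plain two-way unpack
#   key, subkeys = jax.random.split(key, num_devices + 1)
# only succeeds when num_devices + 1 == 2, i.e. a single device, and raises
# "too many values to unpack" otherwise.

# Starred unpacking works for any device count: subkeys collects the
# remaining num_devices keys, one per device.
key, *subkeys = jax.random.split(key, num_devices + 1)
subkeys = jnp.stack(subkeys)  # shape (num_devices, 2), ready to feed to pmap
print(subkeys.shape)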

Let me know if this makes sense. Also, if you like, I can submit a tiny PR to fix them.

Issue Analytics

  • State: closed
  • Created 3 years ago
  • Comments: 6

Top GitHub Comments

1 reaction
jsspencer commented, Feb 7, 2021

That’s the correct thing to do. pmean is a collective reduce (i.e. like MPI_AllReduce rather than MPI_Reduce, if you’re familiar with MPI) and the result is sharded across all devices. The logging call requires the data on the host, which transfers the array back, resulting in an array of length num_devices whose elements are, as you say, identical due to the pmean. The same thing is done in several places during the main training loop.
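
A minimal sketch of the behaviour described above (illustrative only, not the FermiNet training loop; it assumes it runs on a host where JAX sees one or more local devices):

import jax
import jax.numpy as jnp
import numpy as np

num_devices = jax.local_device_count()
x = jnp.arange(num_devices, dtype=jnp.float32)  # one value per device

# pmean is a collective reduce: after it runs, every device holds the same mean.
mean_fn = jax.pmap(lambda v: jax.lax.pmean(v, axis_name='i'), axis_name='i')
result = mean_fn(x)

# Pulling the result back to the host gives one (identical) entry per device,
# so logging code should index a single element, e.g. float(result[0]).
host_result = np.asarray(result)
print(host_result.shape)      # (num_devices,)
print(float(host_result[0]))  # the mean, replicated on every device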

0 reactions
jsspencer commented, Feb 12, 2021

Fixed in #17. Thanks for spotting this and sending the patch!


Top Results From Across the Web

Jax multi-gpu randomly hangs forever · Issue #10969 - GitHub
We are facing a problem where a training and validation code based on jax/flax hangs randomly on a multi-gpu host. Using a single...

Introduction to porting Python to GPU with JAX. - NERSC
Porting the code: Kernels were ported from C++ to Numpy to JAX and validated using unit tests. Kernels loop on irregular intervals, we...

Using JAX in multi-host and multi-process environments
This guide explains how to use JAX in environments such as GPU clusters and Cloud TPU pods where accelerators are spread across multiple...

Getting started with JAX (MLPs, CNNs & RNNs)
Broadly speaking there are two types of automatic differentiation: ... JAX automatically detects whether you have access to a GPU or TPU.

[D] Should We Be Using JAX in 2022? : r/MachineLearning
Will JAX's functional paradigm lead to issues for those without functional experience, especially in Deep Learning?
