vjp3 and sharding #31958
Unanswered
PhilipVinc asked this question in Q&A

@yashk2810 pointed out to me that you are working on a new variant of vjp that would allow easier accumulation over scan loops, dubbed vjp3. I saw this test case, which is quite clear:
https://cs.opensource.google/jax/jax/+/main:tests/mutable_array_test.py;l=781-810?q=mutable_array_test&ss=jax%2Fjax

This is a very common use case for us in sci-ml and would remove the need for our batched_vjp implementation, which is very messy. However, one common requirement we have is to shard the input XS across multiple devices. The 'smart' implementation of this would compute a vjp on every device and do a single global reduction at the end.

I started with a naive implementation (see the sketch below), which does a global reduction at every scan iteration and is therefore suboptimal. Then I tried to hide everything under a shard_map (see https://gist.github.com/PhilipVinc/e34a5b5c5e2cc46565c0f74ba0c8491f), but I suspect that I cannot create an array_ref inside of a shard_map, because I get an undefined memory error...

Would be happy to have some insights!
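The code snippet originally attached to this question is not reproduced in this page. As a stand-in, here is a minimal sketch of what such a naive scan-accumulated VJP could look like; the `loss` function, `batched_grad`, and all shapes are hypothetical, not taken from the original post:

```python
import jax
import jax.numpy as jnp

def loss(params, x):
    # Hypothetical per-minibatch loss, standing in for the real model.
    return jnp.sum((params * x) ** 2)

def batched_grad(params, xs):
    def step(grad_acc, x):
        # One VJP per minibatch; with xs sharded, the accumulation below
        # implies a cross-device reduction at every scan iteration.
        _, pullback = jax.vjp(lambda p: loss(p, x), params)
        (g,) = pullback(1.0)
        return jax.tree.map(jnp.add, grad_acc, g), None

    grad0 = jax.tree.map(jnp.zeros_like, params)
    grads, _ = jax.lax.scan(step, grad0, xs)
    return grads

params = jnp.ones(4)
xs = jnp.arange(32.0).reshape(8, 4)  # 8 minibatches of size 4
print(batched_grad(params, xs))
```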
Replies: 2 comments 4 replies
-
Maybe unreduced can help here? I know it's a bit cryptic, but I am working on a way to return unreduced arrays out of minibatches and then do a global reduction outside the scan loop. I think I am pretty close to getting it to work, which might help with the problem you pointed out?
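The unreduced mechanism mentioned above is still in development, so no real API for it is shown here. As an illustration of the same idea written out by hand, here is a sketch using shard_map: each device accumulates a local, unreduced gradient over its own shard of xs, and a single psum outside the scan loop performs the global reduction. The mesh layout, `loss`, and `sharded_grad` are assumptions made for the example, and shard_map is assumed to be available as `jax.shard_map` (recent JAX versions):

```python
import jax
import jax.numpy as jnp
from functools import partial
from jax.sharding import PartitionSpec as P

# Assumed setup: a one-axis mesh named "data" over all available devices.
mesh = jax.make_mesh((jax.device_count(),), ("data",))

def loss(params, x):
    # Hypothetical per-minibatch loss.
    return jnp.sum((params * x) ** 2)

# params replicated, xs sharded along its leading (scan) axis.
@partial(jax.shard_map, mesh=mesh, in_specs=(P(), P("data")), out_specs=P())
def sharded_grad(params, xs):
    def step(acc, x):
        _, pullback = jax.vjp(lambda p: loss(p, x), params)
        (g,) = pullback(1.0)
        return acc + g, None  # purely local accumulation, no communication

    local_grad, _ = jax.lax.scan(step, jnp.zeros_like(params), xs)
    # One global reduction, outside the scan loop.
    return jax.lax.psum(local_grad, "data")
```

With this layout each device scans only over its own slice of the minibatches (the leading axis of xs must be divisible by the device count), and the only collective is the final psum.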
-
Small question: jax.sharding.set_mesh is not documented. Is it a "decorator" for setting a general sharding option (in addition to auto sharding mode)? NB: I'm trying to gather more information on auto sharding, cf. #32494
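Not an authoritative answer, but since jax.sharding.set_mesh is undocumented, here is a minimal sketch of how it appears to be used, under the assumption (from its name) that it installs a mesh globally rather than decorating individual functions; the "data" axis name is arbitrary:

```python
import jax
import numpy as np
from jax.sharding import NamedSharding, PartitionSpec as P

mesh = jax.make_mesh((jax.device_count(),), ("data",))

# Assumption: set_mesh makes this mesh the ambient/global one, so that
# subsequent sharding-aware APIs can resolve its axis names without the
# mesh being threaded through every call explicitly.
jax.sharding.set_mesh(mesh)

x = jax.device_put(np.arange(16.0), NamedSharding(mesh, P("data")))
print(x.sharding)
```

If a scoped version is wanted instead, jax.sharding.use_mesh appears to offer the same thing as a context manager, which may be closer to the "decorator" framing.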