GPU Memory allocation quirks and documentation clarity? #31906
Unanswered
cjchristopher
asked this question in
General
Replies: 0 comments
Based on https://docs.jax.dev/en/latest/gpu_memory_allocation.html I was led to believe that setting `os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"` was enough to enable full usage of GPU memory, i.e. that there is no need to specify the corresponding `MEM_FRACTION` environment variable, since it would appear not to be relevant if preallocation is disabled. The wording on that documentation page seems to indicate that the fraction applies only IF preallocation is enabled. (As a separate note, jax/jaxlib/xla_client.py, lines 174 to 182 in 5fd7788, refers to `XLA_CLIENT_MEM_FRACTION`, but I'd expect the same semantics unless told otherwise.)
However, if you do disable preallocation you are still confined to 75% of your GPU's memory, and both `MEM_FRACTION` variables still take effect when preallocation is disabled (e.g. `jax.devices()[0].memory_stats()["bytes_limit"]` will show a figure that is ~75% of the available GPU memory).
The reason? XLA itself sets 0.75 as the default fraction unless overridden: https://github.com/openxla/xla/blob/68163dd3407ba791905277a2c6262933e3ccbb75/xla/pjrt/plugin/xla_gpu/xla_gpu_allocator_config.h#L40
I've discovered that setting MEM_FRACTION="1" lets me actually use GPUs to their full potential. Is this something that should be clarified in the documentation? Otherwise it's not at all clear that you are capped at 75% of GPU memory regardless of whether preallocation is enabled.
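For anyone who wants to reproduce this, here is a minimal sketch of the check described above. It assumes the variables are set before JAX is first imported, and that `XLA_PYTHON_CLIENT_MEM_FRACTION` is the variable meant by MEM_FRACTION. `describe_limit` is a hypothetical helper added only for illustration. The script also runs without JAX installed, in which case it only sets the environment variables.

```python
import os

# Set both variables before the first `import jax`, so they are in place
# when the GPU backend initializes (an assumption of this sketch).
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1"  # value reported to work in this post


def describe_limit(stats):
    """Format the allocator limit from a memory_stats() result (hypothetical helper)."""
    if not stats or "bytes_limit" not in stats:
        return "no bytes_limit reported (e.g. CPU backend)"
    return f"bytes_limit = {stats['bytes_limit'] / 2**30:.2f} GiB"


try:
    import jax
except ImportError:
    jax = None  # JAX is not installed; nothing further to inspect

if jax is not None:
    device = jax.devices()[0]  # jax.devices is a function and must be called
    print(device.platform, describe_limit(device.memory_stats()))
```

Comparing the printed figure with the total memory reported by a tool such as nvidia-smi, with and without the `XLA_PYTHON_CLIENT_MEM_FRACTION` line, should show whether the 75% cap is in effect.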