Compiler passes for OpenXLA for finite element methods #28789
connorjward started this conversation in General
Hello,
I am interested in using OpenXLA (via JAX) for performing finite element computations. As far as I can tell the necessary interface already exists, but one critical compiler transformation appears to be missing.
To be more specific, I would like to be able to write code like the following:
The crucial transformation that doesn't appear to be happening is "tile and fuse". Since the axis `n` appears in every step of the computation, it should be possible to tile the iteration and compute, say, 10 rows at a time. As things stand this doesn't happen, and `input_packed` is fully materialised as an enormous temporary, which is not performant.

I have found this previous unsolved issue, which I believe describes a very similar problem.
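For illustration, here is a minimal sketch of doing the tiling by hand in JAX. It assumes a typical finite element assembly pattern: gather nodal values per cell into `input_packed` (shape `(n, k)`), apply the same dense element kernel to every cell, then scatter-add the results back to the global vector. The names `cell_node_map`, `local_matrix`, `assemble_untiled` and `assemble_tiled` are hypothetical stand-ins, not the code from this post. The tiled version walks over blocks of 10 cells with `jax.lax.scan`, so each step only needs a `(10, k)` slice of the packed input. Whether XLA actually keeps that per-step temporary small depends on how it lowers the loop, but this is roughly the transformation one would hope it could derive automatically.

```python
import jax
import jax.numpy as jnp
import numpy as np


def assemble_untiled(u, cell_node_map, local_matrix):
    # Gather one row of nodal values per cell: shape (n, k).
    # This is the large temporary that gets fully materialised.
    input_packed = u[cell_node_map]
    # Apply the same (hypothetical) dense element kernel to every cell.
    local_out = input_packed @ local_matrix.T
    # Scatter-add element contributions; repeated node indices accumulate.
    return jnp.zeros_like(u).at[cell_node_map].add(local_out)


def assemble_tiled(u, cell_node_map, local_matrix, tile=10):
    n, k = cell_node_map.shape
    if n % tile != 0:
        # This sketch does not handle a ragged final tile.
        raise ValueError("n must be divisible by the tile size")
    # Split the cell axis n into (n // tile) blocks of `tile` cells.
    tiles = cell_node_map.reshape(n // tile, tile, k)

    def body(acc, tile_map):
        # Gather, compute and scatter for one block of cells, so only a
        # (tile, k) slice of the packed input is needed per step.
        packed = u[tile_map]
        local_out = packed @ local_matrix.T
        return acc.at[tile_map].add(local_out), None

    out, _ = jax.lax.scan(body, jnp.zeros_like(u), tiles)
    return out


# Usage on small random data (hypothetical mesh: 40 cells, 3 nodes per cell).
rng = np.random.default_rng(0)
num_nodes, n, k = 50, 40, 3
cell_node_map = jnp.asarray(rng.integers(0, num_nodes, size=(n, k)))
u = jnp.asarray(rng.standard_normal(num_nodes))
local_matrix = jnp.asarray(rng.standard_normal((k, k)))

ref = assemble_untiled(u, cell_node_map, local_matrix)
tiled = jax.jit(assemble_tiled, static_argnames="tile")(
    u, cell_node_map, local_matrix, tile=10
)
```

The two functions agree up to floating-point summation order. The drawback of the manual version is that `scan` runs the tiles one after another and the tile size is hard-coded, which is part of why having the compiler choose and apply the tiling would be preferable.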
I would love to hear your feedback regarding whether you think it is reasonable to expect OpenXLA to perform this transformation and what would be necessary to enable it. I would potentially be able to have a crack myself some time in the future.
Some additional points to emphasise include: