XLA does too many un-fused transposes #16914

Open
ywrt opened this issue Sep 9, 2024 · 4 comments

@ywrt

ywrt commented Sep 9, 2024

(This is running on an NVIDIA 4090 GPU, with jax '0.4.31'.)

I have code that looks something like the example below. Here, the depth-wise convolution wants the input to be transposed from [batch, sequence, feature] into [batch, feature, sequence] so that it can apply the convolution along sequence.

The output from the convolution is used 3 times, and XLA generates at least 3 separate (fused) transposes, each of which does a full read and write of memory. This is very slow and causes sadness.

Unfortunately, this example code doesn't reproduce the problem: it seems to be quite sensitive to the surrounding code, and trying to trim it down makes most of the issue go away. A screen-grab from the profiled code somewhat shows the issue:

Screenshot from 2024-09-09 11-51-25

After the convolution is a loop-transpose_fusion, and then after the two cutlass gemm kernels, there are two input_transpose_fusion kernels, and then following the kernel__1 is another input_transpose_fusion.
Each of these fusions does a full read/write of memory.
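For a sense of scale, here is a back-of-envelope sketch of what three un-fused transposes cost. The shapes come from the example below (batch 64, sequence 1024, hidden 512). The float32 element size and the roughly 1008 GB/s peak bandwidth figure for the 4090 are assumptions, so treat the result as a lower bound rather than a measurement.

```python
# Rough memory traffic for un-fused transposes of the convolution output.
# Shapes are taken from the example; dtype and bandwidth are assumptions.
batch, sequence, hidden = 64, 1024, 512
bytes_per_element = 4        # assuming float32 activations
peak_bandwidth = 1008e9      # bytes/s, assumed peak for an RTX 4090

tensor_bytes = batch * sequence * hidden * bytes_per_element
traffic_per_transpose = 2 * tensor_bytes   # one full read plus one full write
num_transposes = 3

total_traffic = num_transposes * traffic_per_transpose
best_case_ms = total_traffic / peak_bandwidth * 1e3

print(f"tensor size: {tensor_bytes / 2**20:.0f} MiB")
print(f"traffic for {num_transposes} transposes: {total_traffic / 2**20:.0f} MiB")
print(f"lower bound on time: {best_case_ms:.2f} ms")
```

A single kernel with 1 input and 3 outputs would read the tensor once instead of three times, which is where the saving would come from.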

My main question is: How can I effectively debug this?
E.g., is there a way to log all the GPU kernel calls along with the argument shapes?
Is there some way to see why the transposes didn't fuse into a single kernel with 1 input and 3 outputs?

import jax
import jax.numpy as jnp
import flax.linen as nn

class Griffin(nn.Module):
  hidden: int = 512
  window: int = 16
  @nn.compact
  def __call__(self, x):
    left = nn.Dense(self.hidden)(x)
    left = jax.nn.silu(left)

    right = nn.Dense(self.hidden)(x)
    right = nn.Conv(
        right.shape[-1],
        kernel_size=(self.window,),
        padding='CAUSAL',
        feature_group_count=right.shape[-1],
    )(right)
    
    # Input gate
    gate = nn.Dense(right.shape[-1])(right)
    gate = jax.nn.sigmoid(gate)
    gated = right * gate
    # Generate decay rate.
    decay = nn.Dense(right.shape[-1])(right)
    decay = jnp.exp(-8 * jax.nn.sigmoid(decay))

    # kernel_recur is not defined in this snippet.
    right = kernel_recur(gated, decay) # Apply linear recurrence along axis=1

    o = left * right
    x += nn.Dense(x.shape[-1])(o)
    return x.mean()

net = Griffin()
params = net.init(jax.random.key(0), jnp.zeros((64, 1024, 256)))['params']

o = jax.jit(jax.value_and_grad(net.apply))({'params': params}, jnp.zeros((64, 1024, 256)))

module_0071.jit_apply.sm_8.9_gpu_after_optimizations.txt

@cheshire
Contributor

> E.g., is there a way to log all the GPU kernel calls along with the argument shapes?

Yes, you could use logging in the *thunk files, but I'm not sure it will help you: as you've pointed out, by that point the fusion decisions have already been made.

> My main question is: How can I effectively debug this?

You could start by dumping HLO after every pass with --xla_dump_hlo_pass_re=.* and figuring out which pass makes the bad decisions, or bisecting that somehow.
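From JAX, one way to pass that flag is through the XLA_FLAGS environment variable. The sketch below is only that: the dump directory is a hypothetical path, and combining the flag with --xla_dump_to and --xla_dump_hlo_as_text is an assumption about what is convenient here, not something prescribed in this thread.

```python
import os

def xla_dump_flags(dump_dir, pass_regex=".*"):
    """Build an XLA_FLAGS string that dumps HLO text after every matching pass."""
    return " ".join([
        f"--xla_dump_to={dump_dir}",
        "--xla_dump_hlo_as_text",
        f"--xla_dump_hlo_pass_re={pass_regex}",
    ])

# The flags are read when the XLA backend starts up, so set them before importing jax.
dump_dir = "/tmp/xla_dump"  # hypothetical path, pick any writable directory
existing = os.environ.get("XLA_FLAGS", "")
os.environ["XLA_FLAGS"] = f"{existing} {xla_dump_flags(dump_dir)}".strip()

# import jax  # import jax and run the jitted function only after this point
```

Dumping after every pass produces a large number of files per module, so a narrower pass regex may be more practical once a suspect pass has been identified.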

@ywrt
Author

ywrt commented Sep 10, 2024

> Yes, you could use logging in the *thunk files, but I'm not sure it will help you: as you've pointed out, by that point the fusion decisions have already been made.

Would you know how I can do that? In the first instance, I'm looking to be able to clearly see that the same memory is being read multiple times by different kernel calls. Is there some flag for the runtime that will have it (e.g.) log stream executions along with argument shapes?

> You could start by dumping HLO after every pass with --xla_dump_hlo_pass_re=.* and figuring out what pass makes bad decisions, or bisecting that somehow.

Yes, I've been looking over this, but it's a lot of manual effort matching up the passes due to renaming et al. (oh, and my limited familiarity!).
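One way to reduce the manual matching is to ignore instruction names and only count transpose instructions in each per-pass dump, then look at where the count changes. The following is a minimal sketch under two assumptions: that the dump files are plain HLO text ending in .txt, and that sorting the file names reproduces pass order. The helper names are made up for illustration.

```python
import re
from pathlib import Path

# Matches uses of the `transpose` opcode, e.g. "... = f32[...] transpose(...)".
# Fusion names such as "input_transpose_fusion" are not counted.
TRANSPOSE_RE = re.compile(r"\btranspose\(")

def transpose_counts(dump_dir, pattern="*.txt"):
    """Return (file name, number of transpose instructions) per dump, in name order."""
    counts = []
    for path in sorted(Path(dump_dir).glob(pattern)):
        text = path.read_text(errors="replace")
        counts.append((path.name, len(TRANSPOSE_RE.findall(text))))
    return counts

def report_changes(counts):
    """Return (file name, previous count, new count) wherever the count changes."""
    changes = []
    for (_, prev), (name, cur) in zip(counts, counts[1:]):
        if cur != prev:
            changes.append((name, prev, cur))
    return changes

# Example usage (hypothetical directory and module prefix):
# for name, prev, cur in report_changes(transpose_counts("/tmp/xla_dump", "module_0071.*.txt")):
#     print(f"{prev:3d} -> {cur:3d}  {name}")
```

The count includes transposes inside fused computations, so this narrows down which passes to read rather than identifying the culprit outright.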

Is there any documentation (beyond the source code) for things like the gpu_after_optimizations-buffer-assignment.txt file?

I guess I'm trying to see the memory read/write graph at the CUDA kernel boundaries. Is there some existing way to see this? If not, I guess I'll try to put together a script to generate a dot graph.

@cheshire
Contributor

> Would you know how I can do that?

You could look at the various *thunk files: inside the Run method they all know the buffers they are acting on, and you could either reuse the existing VLOG lines or add your own.
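Existing VLOG lines can usually be switched on per source file through the TF_CPP_VMODULE environment variable, without rebuilding. The sketch below assumes that mechanism applies to the jaxlib build in use. The module name kernel_thunk and the level 3 are illustrative guesses, not values confirmed in this thread, so check the thunk source for the file and level that carry the lines you want.

```python
import os

# Hypothetical selection: the file name (without extension) and VLOG level to enable.
vmodule = {"kernel_thunk": 3}

os.environ["TF_CPP_VMODULE"] = ",".join(
    f"{name}={level}" for name, level in vmodule.items()
)
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"  # do not suppress INFO-level output

# import jax  # import jax only after the environment is set
```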

> I guess I'm trying to see the memory read/write graph at the CUDA kernel boundaries. Is there some existing way to see this?

Yes, that would be buffer assignment. You'll only see offsets there (and you can match them with instruction names in after-optimizations HLO) as the actual memory is allocated at runtime.

@nullhook

fwiw, you can see how to parse the buffer assignment format here: https://github.com/openxla/xla/blob/main/xla/tools/driver.cc#L311
