Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[XLA:GPU] Remove the restrictions that prevent us from fusing the sub…
…channel dequantisation sequence from Triton tiling propagation. There are two cases for the broadcast: 1) Channel quantisation case: a) we have 2d weights + 1d scales + 2d activations. b) In triton prolog we prepare the corresponding block pointers c) inside the for loop along the k-dim every time we load the same 1d tile for scalers, expand it to 2d [block_m,1], and broadcast to block_k elements along newly added dim to [block_m, block_k]. d) then do the multiply and dot 2) Subchannel quantisation case: a) we have 2d weights [M,K] + 2d scales [M,K/q] + 2d activations where q is the subchannel size. b) In triton prolog we prepare the corresponding block pointers c) inside the for loop along the k-dim every time we load the 2d [M,1] tile for scalers and broadcast to block_k elements along the k dim to [block_m, block_k]. d) then do the rest I.e. the difference is that the scalers matrix is the 2d matrix from the very beginning but it is smaller along the k dim and we need to advance it along k dim only by one column instead of block_k columns. It is already 2d, so, we don't need to add the dimension. We could emit the right code if we know that there was the subchannel broadcast and and we know the size of the subchannel. We do this analysis in the triton_tiling_propagation by detecting the broadcast with the follow up bitcast combination like [B,c,M] -> [B,c,q,M] -> [B,c*q,M]. I.e. we do the broadcast but the follow up bitcast merges the broadcasted dim with another nonempty dim. This schema works for the cases when block_k == subchannel_size and for the case when we have split_k == 1. These two restrictions could be addressed in the follow up cls. PiperOrigin-RevId: 725211140
- Loading branch information