# Fused GaLore Adam (WIP)
Various fused implementations of the `Adam` update step per Gradient Low-Rank Projection (GaLore).

This is an initial attempt at optimizing the update step of the `GaLore` `Adam` optimizer.

## Overview
The `GaLore` `Adam` optimizer introduces additional ops to the traditional `adam` update step. Specifically:

- `grad` is projected to low rank --> additional matmul
- `adam` states are updated with `grad` elementwise (same as `Adam` except in low-rank)
- `grad` is projected to full rank --> additional matmul
- `params` are updated with the normalized full rank grad
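For reference, a minimal sketch of these four steps as plain `torch` ops (hyperparameter defaults and the `M >= N` projection convention are illustrative assumptions, not the exact reference code):

```python
import torch


def galore_adam_step(param, grad, proj, exp_avg, exp_avg2,
                     lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """Unfused GaLore-style Adam step (bias correction and scaling omitted for brevity).

    Assumes param/grad are (M, N) with M >= N and proj is (N, rank).
    """
    # 1. project grad to low rank --> additional matmul
    grad_lr = grad @ proj                                   # (M, rank)

    # 2. elementwise Adam state updates, in low rank
    exp_avg.mul_(beta1).add_(grad_lr, alpha=1 - beta1)
    exp_avg2.mul_(beta2).addcmul_(grad_lr, grad_lr, value=1 - beta2)
    norm_grad_lr = exp_avg / (exp_avg2.sqrt() + eps)        # normalized low-rank grad

    # 3. project normalized grad back to full rank --> additional matmul
    norm_grad = norm_grad_lr @ proj.T                       # (M, N)

    # 4. update params with the normalized full rank grad
    param.add_(norm_grad, alpha=-lr)
    return param
```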
## Implementation

Various fusions were attempted across 2 kernel implementations:

- `Fused`: `adam` state updates are loaded and updated (inplace) during the first `matmul`; the remaining steps are fused into the second `matmul`
- `Hybrid`: `torch matmul` (i.e., `cuBlas`) for the down projection, `Fused` for the remaining steps
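To make the fusion idea concrete, below is a stripped-down Triton sketch of a down-projection matmul whose epilogue applies the elementwise Adam moment updates in-place on the tile it just produced. This is an illustration only, not the PR's kernel: argument names, layouts, and the exact fusion boundary are assumptions.

```python
import triton
import triton.language as tl


@triton.jit
def down_proj_adam_epilogue(
    a_ptr, b_ptr,                 # grad (M, K) and projection (K, N); N plays the role of `rank`
    exp_avg_ptr, exp_avg2_ptr,    # Adam moments, shape (M, N), updated in-place
    M, N, K,
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_sm, stride_sn,
    beta1, beta2,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
    # Launch grid: (cdiv(M, BLOCK_M), cdiv(N, BLOCK_N))
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)

    a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    # Main loop: plain tiled matmul producing one tile of the low-rank grad.
    for k in range(0, K, BLOCK_K):
        a = tl.load(a_ptrs, mask=(offs_m[:, None] < M) & (offs_k[None, :] + k < K), other=0.0)
        b = tl.load(b_ptrs, mask=(offs_k[:, None] + k < K) & (offs_n[None, :] < N), other=0.0)
        acc += tl.dot(a, b)
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk

    # Epilogue: fuse the elementwise Adam moment update on the freshly computed tile,
    # so the low-rank grad never makes a separate round trip through global memory.
    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    s_offs = offs_m[:, None] * stride_sm + offs_n[None, :] * stride_sn
    exp_avg = tl.load(exp_avg_ptr + s_offs, mask=mask, other=0.0)
    exp_avg2 = tl.load(exp_avg2_ptr + s_offs, mask=mask, other=0.0)
    exp_avg = beta1 * exp_avg + (1.0 - beta1) * acc
    exp_avg2 = beta2 * exp_avg2 + (1.0 - beta2) * acc * acc
    tl.store(exp_avg_ptr + s_offs, exp_avg, mask=mask)
    tl.store(exp_avg2_ptr + s_offs, exp_avg2, mask=mask)
```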
## Performance

Below are benchmarks for various kernels:

- `torch` - reference `torch` implementation where each of the steps is implemented verbatim per above
- `hybrid` - see above
- `fused` - see above
- `compiled` - `torch` reference implementation compiled using `torch.compile` with `fullgraph=True` and `mode="max-autotune"`.
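Assuming a reference step function along the lines of the `galore_adam_step` sketch above, the `compiled` variant is produced roughly as follows:

```python
import torch

# Compile the unfused reference step with the settings used for the `compiled` benchmark.
compiled_step = torch.compile(galore_adam_step, fullgraph=True, mode="max-autotune")
```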
Configs for each benchmark are the `grad (param)` shape, the `dtype` of `grad` and the `adam` states, and `allow_tf32`, i.e. whether `torch` and `triton` matmuls are allowed to use `TF32` tensor cores (see `Discussion`).
- Grad shape: `4096x4096`, dtype: `torch.float32`, allow_tf32: `False`
- Grad shape: `4096x4096`, dtype: `torch.float32`, allow_tf32: `True`
- Grad shape: `4096x11008`, dtype: `torch.float32`, allow_tf32: `False`
- Grad shape: `4096x11008`, dtype: `torch.float32`, allow_tf32: `True`
## Accuracy

Comparison to reference `torch` implementation:

## Discussion
### Down Projection GEMM Shape

The motivation for the `hybrid` approach is the unconventional matrix shapes of the down projection (Step 1):

- one dim of the `grad` matrix is maintained while the other is projected to low rank, per the `GaLore` algorithm
- if `M >= N`, the GEMM is of shape (`M x N`) x (`N x rank`) = (`M x rank`); otherwise it is (`rank x M`) x (`M x N`) = (`rank x N`)
- since `{M, N} >> rank` by definition, this results in a large reduction dimension relative to one of the output dimensions (the output matrix is either fat or skinny)
- such shapes are better served by the split-k / parallel reduction `GEMM` paradigm, which is more tailored for shapes where both output dims are smaller than the reduction dimension
- the `triton` autotuner did not find a performant config for the down projection step, despite tuning across many compute and io-bound configs (see `fused.triton_utils.kernels.matmul.py`)
- benchmarking the `triton`-tuned `matmul` against the default `torch.matmul` for these shapes showed worse performance for `torch.float32`
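As a quick sanity check of the shapes for the `M >= N` case (the `rank` value here is just an illustrative choice):

```python
import torch

M, N, rank = 4096, 4096, 128        # rank chosen for illustration only
grad = torch.randn(M, N)
proj = torch.randn(N, rank)

low_rank_grad = grad @ proj         # (M x N) @ (N x rank) -> (M x rank)
print(low_rank_grad.shape)          # torch.Size([4096, 128])
# The reduction dimension (N = 4096) is much larger than one of the
# output dimensions (rank = 128), i.e. the output matrix is skinny.
```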
### Effect of TF32 tensor cores

`allow_tf32`: this has a significant impact on the relative performance of `triton` vs `torch` matmuls. Benchmarks of the down projection `matmul` show that:

- with `allow_tf32=True` for both, `triton` exhibits a ~1.30x performance improvement over `torch`
- with `allow_tf32=False`, performance of `triton` degrades significantly, to ~0.50x of `torch`
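For reference, the `torch`-side switch is the standard global flag (how the `triton` kernels gate TF32 is not shown here):

```python
import torch

# Global PyTorch flags controlling whether matmuls (and convolutions) may use
# TF32 tensor cores; this is the allow_tf32 knob referenced in the benchmarks.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
```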
See this torch note for more details on this feature.

Note: This might be less of a concern given this incoming triton PR, which implements a fast `TF32` trick that improves both performance and accuracy.

## Repro
`tests/test_fused_kernels.py` is a `CLI` that has 2 modes: one for testing kernel accuracy, and the other for benchmarking across a number of configs.

### Examples

#### Accuracy

Test accuracy of `torch` vs `hybrid` for `M=4096`, `N=4096`, `rank=128`, and `tf32` switched on:

#### Benchmark

Benchmark across all kernels without `tf32`:

### Additional options
Note: Passing in the additional flag `--verbose` will show `triton` autotuning logs -- I customized the `triton` autotuner to spit out configs and other details.

## Test Env
- GPU Device: NVIDIA RTX A6000
  - Compute capability: 8.6
  - Memory: 48676 MB
  - Multiprocessor (SM) count: 84
- `torch` version: 2.3.0.dev20240310+cu118
- `triton` version: 3.0.0
## Next Steps

- `FusedGaLoreOptimizer`
- `Cutlass` - given the fixed GEMM shape, experiment with `Cutlass` GEMMs (`split-k`, `stream-k`, fast `tensorops`). Interestingly, profiling `torch.matmul` for the down projection shows that `cuBlas` dispatches to a `Cutlass` kernel of shape `128x128x16`.
- `AdamW8bit`
- `torch.compile` performance