Grouped convolutions are much slower than standard convolutions when increasing kernel dilation #26266
Labels
bug: Something isn't working
NVIDIA GPU: Issues specific to NVIDIA GPUs
performance: make things lean and fast
Description
I’ve been experimenting with depthwise grouped convolutions in JAX, and noticed that as I increase the kernel dilation rate, they become much slower than standard convolutions, even though they are performing fewer mathematical operations. I observed this on several different GPUs, including T4, P100 and H100.
The following example demonstrates this issue using 1D convolutions across different kernel sizes and channel sizes. I tried 3 different methods to implement depthwise convolutions:

1. Calling `lax.conv_general_dilated(…, feature_group_count=num_channels)` directly.
2. Using `vmap` to map `lax.conv_general_dilated(…, feature_group_count=1)` across the `out_channels` dim of both the input and the kernel.
3. Manually slicing the `out_channels` dim of both the input and the kernel, then calling `lax.conv_general_dilated(…, feature_group_count=1)` on these slices and concatenating the results.

Some observations:
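For reference, the three approaches can be sketched roughly as follows. This is a minimal illustration, not the exact benchmark code: the sizes, the NCW/OIW layouts, and `padding="SAME"` are my own assumptions.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Illustrative sizes (assumptions, not the original benchmark's values).
batch, channels, width, kernel_size, dilation = 8, 16, 128, 3, 2
x = jnp.ones((batch, channels, width))          # NCW input
w = jnp.ones((channels, 1, kernel_size))        # OIW depthwise kernel
dn = lax.conv_dimension_numbers(x.shape, w.shape, ("NCW", "OIW", "NCW"))

# Method 1: a single grouped conv with feature_group_count=num_channels.
y1 = lax.conv_general_dilated(
    x, w, window_strides=(1,), padding="SAME",
    rhs_dilation=(dilation,), dimension_numbers=dn,
    feature_group_count=channels)

# Method 2: vmap a single-channel conv over the channel dimension.
def single_channel_conv(xc, wc):
    # xc: (batch, width) slice of one channel; wc: (1, kernel_size)
    out = lax.conv_general_dilated(
        xc[:, None, :], wc[None, :, :], window_strides=(1,),
        padding="SAME", rhs_dilation=(dilation,),
        dimension_numbers=dn, feature_group_count=1)
    return out[:, 0, :]

y2 = jax.vmap(single_channel_conv, in_axes=(1, 0), out_axes=1)(x, w)

# Method 3: explicit slicing along the channel dim, then concatenate.
y3 = jnp.concatenate(
    [lax.conv_general_dilated(
        x[:, c:c + 1, :], w[c:c + 1], window_strides=(1,),
        padding="SAME", rhs_dilation=(dilation,),
        dimension_numbers=dn, feature_group_count=1)
     for c in range(channels)], axis=1)
```

All three produce the same values; only how they lower to GPU kernels differs.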
- With method 1 (`feature_group_count=num_channels`), depthwise convolutions are often faster than standard convolutions for small dilation rates, but their performance scales poorly as we increase the dilation rate.
- Method 2 (`vmap`) has the same performance as method 1; inspecting the `jaxpr`, I see that it is indeed calling `lax.conv_general_dilated(…, feature_group_count=num_channels)` under the hood.
- Method 3 (`concatenate`) is slower than the other 2 methods for small dilation rates, but it scales much better as we increase the dilation rate. For large kernel and channel sizes and `dilation>1`, it is also considerably faster than standard convolutions.
- The depthwise methods show a performance bump when going from `dilation=1` to `dilation=2`, but aside from that, their performance is similar across higher dilation rates. For large kernel sizes, this bump is particularly notable.
- Even for `kernel_size=1`, where I believe dilation has no effect, all depthwise methods show a bump when going from `dilation=1` to `dilation=2`. We can easily work around this by using `dilation=1` here, of course, but just thought I'd include this in case it helps in the diagnosis.

I would appreciate it if someone could help shed some light on this issue, or maybe share some creative work-arounds to get more consistent performance from depthwise convolutions with dilation. Thank you for your help!
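For anyone trying to reproduce timings like these, a minimal harness along these lines should give stable numbers (a sketch; `iters` and the warm-up policy are my own choices). The warm-up call matters because the first invocation pays the XLA compilation cost, and `block_until_ready()` matters because JAX dispatches work asynchronously:

```python
import time
import jax
import jax.numpy as jnp

def benchmark(fn, *args, iters=10):
    """Return mean wall-clock seconds per call of the jitted fn."""
    jitted = jax.jit(fn)
    jitted(*args).block_until_ready()  # warm-up: triggers compilation
    start = time.perf_counter()
    for _ in range(iters):
        jitted(*args).block_until_ready()  # wait for async dispatch
    return (time.perf_counter() - start) / iters
```

For example, `benchmark(lambda a: a * 2.0, jnp.ones((1024,)))` times a trivial elementwise op; the conv variants above can be dropped in the same way.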
System info (python version, jaxlib version, accelerator, etc.)