
Grouped convolutions are much slower than standard convolutions when increasing kernel dilation #26266

Open
karunrao97 opened this issue Feb 2, 2025 · 1 comment
Labels: bug (Something isn't working) · NVIDIA GPU (Issues specific to NVIDIA GPUs) · performance (make things lean and fast)


karunrao97 commented Feb 2, 2025

Description

I’ve been experimenting with depthwise (grouped) convolutions in JAX, and noticed that as I increase the kernel dilation rate, they become much slower than standard convolutions, even though they perform far fewer arithmetic operations (roughly a factor of num_channels fewer multiply-adds). I observed this on several different GPUs, including T4, P100, and H100.

The following example demonstrates the issue using 1D convolutions across a range of kernel sizes and channel counts. I tried three different methods of implementing depthwise convolutions:

  1. Use lax.conv_general_dilated(…, feature_group_count=num_channels).
  2. Use vmap to map lax.conv_general_dilated(…, feature_group_count=1) across the out_channels dim of both the input and the kernel.
  3. Loop across slices of the out_channels dim of both the input and the kernel, then call lax.conv_general_dilated(…, feature_group_count=1) on these slices and concatenate the results.

from functools import partial

import jax
import numpy as np
import pandas as pd
import seaborn as sns

from jax import lax
from jax import numpy as jnp
from matplotlib import pyplot as plt

batch_size = 32
num_steps = 2**10

def get_xk(kernel_size: int, num_channels: int, dilation: int, is_depthwise: bool):
    # Deterministic test data in [0, 1).
    fn = lambda shape: jnp.arange(size := int(np.prod(shape)), dtype=jnp.float32).reshape(shape) / size
    # Pad the time axis so a VALID conv produces exactly num_steps outputs.
    x = fn((batch_size, num_steps + dilation * (kernel_size - 1), num_channels))
    # Depthwise kernels have a single input channel per group ("TIO" layout).
    k = fn((kernel_size, 1 if is_depthwise else num_channels, num_channels))
    return x, k

conv = partial(
    lax.conv_general_dilated,
    window_strides=(1,),
    padding="VALID",
    dimension_numbers=("NTC", "TIO", "NTC"),
)

def standard(x, k, dilation: int):
    return conv(x, k, rhs_dilation=(dilation,))

def depthwise(x, k, dilation: int):
    num_channels = x.shape[-1]
    return conv(x, k, rhs_dilation=(dilation,), feature_group_count=num_channels)

def depthwise_vmap(x, k, dilation: int):
    fn = jax.vmap(partial(conv, rhs_dilation=(dilation,)), in_axes=-2, out_axes=-2)
    return fn(x[..., jnp.newaxis], k[..., jnp.newaxis]).squeeze(axis=-1)

def depthwise_concat(x, k, dilation):
    num_channels = x.shape[-1]
    ys = [conv(x[..., i: i + 1], k[..., i: i + 1], rhs_dilation=(dilation,)) for i in range(num_channels)]
    return jnp.concatenate(ys, axis=-1)

df = []
for kernel_size in [1, 4, 16, 64]:
    for num_channels in [32, 128, 512]:
        for dilation in [2**i for i in range(8)]:
            depthwise_y = None
            for i, fn in enumerate([standard, depthwise, depthwise_vmap, depthwise_concat]):
                fn_name = fn.__name__
                print(f"{fn_name:s}, {kernel_size=}, {num_channels=}, {dilation=}")
                is_depthwise = i > 0
                x, k = get_xk(kernel_size, num_channels, dilation, is_depthwise)
                fn = jax.jit(partial(fn, dilation=dilation))
                y = jax.block_until_ready(fn(x, k))  # compile fn
                assert y.shape == (batch_size, num_steps, num_channels)
                if is_depthwise:
                    assert depthwise_y is None or jnp.allclose(y, depthwise_y)
                    depthwise_y = y
                t = %timeit -o -n 10 jax.block_until_ready(fn(x, k))  # IPython magic; this script runs in a notebook
                df.append([fn_name, kernel_size, num_channels, dilation, t.average])
            print()
df = pd.DataFrame(df, columns=["conv_fn", "kernel_size", "num_channels", "dilation", "time_s"])

sns.relplot(
    df,
    x="dilation",
    y="time_s",
    row="kernel_size",
    col="num_channels",
    hue="conv_fn",
    style="conv_fn",
    kind="line",
    markers=True,
    facet_kws=dict(sharey=False),
).set(yscale="log")
plt.xscale("log", base=2);

[Figure: runtime (time_s, log scale) vs. dilation (log2 scale) for each conv_fn, faceted by kernel_size (rows) and num_channels (columns)]

Some observations:

  • With method 1 (feature_group_count=num_channels), depthwise convolutions are often faster than standard convolutions for small dilation rates, but their performance scales poorly as the dilation rate increases.
  • Method 2 (vmap) has the same performance as method 1: inspecting the jaxpr, I see that it is indeed calling lax.conv_general_dilated(…, feature_group_count=num_channels) under the hood (a quick way to check this is shown right after this list).
  • Method 3 (concatenate) is slower than the other two methods for small dilation rates, but it scales much better as the dilation rate increases. For large kernel and channel sizes and dilation > 1, it is also considerably faster than standard convolutions.
  • Standard convolutions also show a noticeable bump going from dilation=1 to dilation=2, but aside from that, their performance is similar across higher dilation rates. For large kernel sizes, this bump is particularly pronounced.
  • Even for kernel_size=1, where dilation should have no effect, all depthwise methods show a bump going from dilation=1 to dilation=2. We can easily work around this by forcing dilation=1 in that case, of course, but I thought I'd include it in case it helps with the diagnosis.
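A minimal way to confirm the second observation, assuming the definitions above (the shapes are just one arbitrary configuration from the sweep):

x, k = get_xk(kernel_size=16, num_channels=32, dilation=2, is_depthwise=True)
print(jax.make_jaxpr(partial(depthwise_vmap, dilation=2))(x, k))
# The printed jaxpr contains a single conv_general_dilated call with
# feature_group_count=32, i.e. the same primitive that method 1 emits.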

I would appreciate it if someone could shed some light on this issue, or perhaps share some creative workarounds to get more consistent performance from depthwise convolutions with dilation. Thank you for your help!
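In case it helps with the diagnosis, here is a sketch (assuming the definitions above) of how to dump the optimized HLO for one configuration, to see which convolution custom call XLA ends up selecting:

x, k = get_xk(kernel_size=16, num_channels=512, dilation=64, is_depthwise=True)
compiled = jax.jit(partial(depthwise, dilation=64)).lower(x, k).compile()
print(compiled.as_text())  # look for the convolution custom call (cuDNN on GPU)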

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.33
jaxlib: 0.4.33
numpy:  1.26.4
python: 3.10.12 (main, Nov  6 2024, 20:22:13) [GCC 11.4.0]
jax.devices (1 total, 1 local): [CudaDevice(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='5c396733a675', release='6.6.56+', version='#1 SMP PREEMPT_DYNAMIC Sun Nov 10 10:07:59 UTC 2024', machine='x86_64')


$ nvidia-smi
Sun Feb  2 19:24:20 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 560.35.03              Driver Version: 560.35.03      CUDA Version: 12.6     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  Tesla P100-PCIE-16GB           Off |   00000000:00:04.0 Off |                    0 |
| N/A   57C    P0             34W /  250W |     259MiB /  16384MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                                                         
+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
+-----------------------------------------------------------------------------------------+
karunrao97 added the bug label Feb 2, 2025
karunrao97 (Author) commented

I came up with yet another implementation of depthwise convolutions, which simply expands the depthwise kernel into an equivalent standard kernel and then runs a standard convolution:

def depthwise_sparse(x, k, dilation):
    num_channels = x.shape[-1]
    # Embed the depthwise kernel (K, 1, C) into a dense (K, C, C) kernel
    # that is zero everywhere except on the channel diagonal.
    k = np.eye(num_channels)[np.newaxis] * k
    return conv(x, k, rhs_dilation=(dilation,))

I verified that this produces the same result as the other depthwise methods described above, and it has the same performance as standard convolutions. This still feels unsatisfactory, though, since we don't reap the performance benefits of depthwise convolutions.
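For example, the equivalence can be sanity-checked like this (one arbitrary configuration, assuming the definitions above):

x, k = get_xk(kernel_size=4, num_channels=32, dilation=4, is_depthwise=True)
assert jnp.allclose(depthwise_sparse(x, k, 4), depthwise(x, k, 4))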

hawkinsp added the NVIDIA GPU and performance labels Feb 3, 2025