Fixed kernel launch config for permute021 (#957)
Summary:
Pull Request resolved: #957

Make sure grid_z does not exceed 65535, the CUDA upper bound for gridDim.z.
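For context: CUDA allows gridDim.x to be as large as 2^31 - 1, while gridDim.y and gridDim.z are capped at 65535, which is why a large batch count must not be mapped onto blockIdx.z. A minimal, hypothetical host-side sketch of that constraint (the helper name and the example shapes are illustrative, not part of this commit):

#include <cuda_runtime.h>
#include <cstdio>

// Hypothetical helper: checks a launch grid against the CUDA limits that
// motivate this fix. Only gridDim.x may exceed 65535.
static bool grid_is_launchable(dim3 grid) {
  constexpr unsigned int kMaxGridX = 2147483647u; // 2^31 - 1
  constexpr unsigned int kMaxGridYZ = 65535u;
  return grid.x <= kMaxGridX && grid.y <= kMaxGridYZ && grid.z <= kMaxGridYZ;
}

int main() {
  dim3 bad(1, 1, 409600);   // 409600 blocks on gridDim.z cannot be launched
  dim3 good(409600, 1, 1);  // the same count is fine on gridDim.x
  printf("bad: %d, good: %d\n", grid_is_launchable(bad), grid_is_launchable(good));
  return 0;
}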

Reviewed By: muchulee8

Differential Revision: D50517471

fbshipit-source-id: bc33403f6955793380e1babc934087c79ee83784
chenyang78 authored and facebook-github-bot committed Oct 22, 2023
1 parent 3c4ba48 commit 0c38a45
Showing 2 changed files with 83 additions and 16 deletions.
python/aitemplate/backend/common/tensor/permute021_common.py (80 additions, 16 deletions)
@@ -84,7 +84,36 @@
{{tensor_accessor_libs}}
template <typename T>
// blockIdx.x -> ni
// blockIdx.y -> hwi
// blockIdx.z -> ci
__device__ __forceinline__ void block_fn_nhc(int32_t& ni, int32_t& hwi, int32_t& ci) {
ni = blockIdx.x;
hwi = blockIdx.y;
ci = blockIdx.z;
}
// blockIdx.x -> ni
// blockIdx.y -> ci
// blockIdx.z -> hwi
__device__ __forceinline__ void block_fn_nch(int32_t& ni, int32_t& hwi, int32_t& ci) {
ni = blockIdx.x;
ci = blockIdx.y;
hwi = blockIdx.z;
}
// blockIdx.x -> ci
// blockIdx.y -> hwi
// blockIdx.z -> ni
__device__ __forceinline__ void block_fn_chn(int32_t& ni, int32_t& hwi, int32_t& ci) {
ci = blockIdx.x;
hwi = blockIdx.y;
ni = blockIdx.z;
}
using BlockFunc = void (*)(int32_t&, int32_t&, int32_t&);
template <typename T, BlockFunc BLOCK_FN>
__global__ void permute021_kernel(T *output,
const T *input,
const int64_t n,
@@ -101,9 +130,11 @@
const int32_t tid = threadIdx.y * blockDim.x + threadIdx.x;
const int32_t wid = tid / TILE_SIZE;
const int32_t lid = tid % TILE_SIZE;
const int32_t ni = blockIdx.z;
const int32_t hwi0 = blockIdx.y * TILE_SIZE;
const int32_t ci0 = blockIdx.x * TILE_SIZE;
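// Recover the logical (n, hw, c) block indices from whichever blockIdx
// ordering the host-side launch chose.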
int32_t ni_tmp, hwi_tmp, ci_tmp;
BLOCK_FN(ni_tmp, hwi_tmp, ci_tmp);
const int32_t ni = ni_tmp;
const int32_t hwi0 = hwi_tmp * TILE_SIZE;
const int32_t ci0 = ci_tmp * TILE_SIZE;
size_t input_idx = ni * hwc + (hwi0 + wid) * c + ci0;
@@ -172,21 +203,54 @@
const int32_t x_dim1 = x_dims[rank-2];
const int32_t x_dim2 = x_dims[rank-1];
const int64_t n = x_dim0;
#define THROW_INVALID_LAUNCH_CONFIG \
throw std::runtime_error( \
std::string("invalid cuda launch config: ") + \
std::to_string(grid_c) + ", " + \
std::to_string(grid_hw) + ", " + \
std::to_string(grid_n));
const int32_t n = static_cast<int32_t>(x_dim0);
const int32_t h = 1;
const int32_t w = x_dim1;
const int32_t c = x_dim2;
dim3 grid((c + TILE_SIZE - 1) / TILE_SIZE, (h * w + TILE_SIZE - 1) / TILE_SIZE, n);
dim3 block(TILE_SIZE, TILE_SIZE / CH_K);
permute021_kernel<{{lib_dtype}}><<<grid, block, 0, stream>>>(
static_cast<{{lib_dtype}}*>(out_ptr),
static_cast<const {{lib_dtype}}*>(in_ptr),
n,
h,
w,
c,
input_accessor
);
const int32_t grid_c = (c + TILE_SIZE - 1) / TILE_SIZE;
const int32_t grid_hw = (h * w + TILE_SIZE - 1) / TILE_SIZE;
const int32_t grid_n = n;
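// Only gridDim.x may be as large as 2^31 - 1; gridDim.y and gridDim.z are
// capped at 65535, so the ordering chosen below must put any oversized tile
// count on blockIdx.x.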
constexpr int32_t max_grid_z = 65535;
constexpr int32_t max_grid_x = 2147483647;
if (grid_c > max_grid_x || grid_hw > max_grid_x || grid_n > max_grid_x) {
THROW_INVALID_LAUNCH_CONFIG
}
if ((grid_c <= max_grid_z && grid_hw <= max_grid_z && grid_n <= max_grid_z) ||
(grid_c > max_grid_z && grid_hw <= max_grid_z && grid_n <= max_grid_z)) {
dim3 grid(grid_c, grid_hw, grid_n);
dim3 block(TILE_SIZE, TILE_SIZE / CH_K);
permute021_kernel<{{lib_dtype}}, block_fn_chn><<<grid, block, 0, stream>>>(
static_cast<{{lib_dtype}}*>(out_ptr),
static_cast<const {{lib_dtype}}*>(in_ptr),
n, h, w, c, input_accessor
);
} else if (grid_n > max_grid_z && grid_hw <= max_grid_z && grid_c <= max_grid_z) {
dim3 grid(grid_n, grid_c, grid_hw);
dim3 block(TILE_SIZE, TILE_SIZE / CH_K);
permute021_kernel<{{lib_dtype}}, block_fn_nch><<<grid, block, 0, stream>>>(
static_cast<{{lib_dtype}}*>(out_ptr),
static_cast<const {{lib_dtype}}*>(in_ptr),
n, h, w, c, input_accessor
);
} else if (grid_n > max_grid_z && grid_hw <= max_grid_z && grid_c <= max_grid_z) {
dim3 grid(grid_n, grid_hw, grid_c);
dim3 block(TILE_SIZE, TILE_SIZE / CH_K);
permute021_kernel<{{lib_dtype}}, block_fn_nhc><<<grid, block, 0, stream>>>(
static_cast<{{lib_dtype}}*>(out_ptr),
static_cast<const {{lib_dtype}}*>(in_ptr),
n, h, w, c, input_accessor
);
} else {
THROW_INVALID_LAUNCH_CONFIG
}
}
} // namespace
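The dispatch in this hunk follows a single rule: any tile count that can exceed 65535 has to ride on gridDim.x, and the block_fn_* mappings let the kernel recover (ni, hwi, ci) no matter which ordering was launched. Below is a simplified, hypothetical selector showing the same idea outside the generated template; the names are illustrative, and the counts are assumed to have already passed the gridDim.x bound check.

// Hypothetical sketch of the ordering choice, not the generated code.
enum class GridOrder { CHN, NCH, Invalid };

GridOrder pick_grid_order(int grid_c, int grid_hw, int grid_n) {
  constexpr int kMaxYZ = 65535;
  if (grid_hw <= kMaxYZ && grid_n <= kMaxYZ) {
    // Default layout: grid(grid_c, grid_hw, grid_n); channel tiles ride on x,
    // so grid_c may be arbitrarily large.
    return GridOrder::CHN;
  }
  if (grid_c <= kMaxYZ && grid_hw <= kMaxYZ) {
    // Large batch: grid(grid_n, grid_c, grid_hw); n rides on x.
    return GridOrder::NCH;
  }
  // Either grid_hw is too large for y/z, or two counts exceed the cap at once.
  return GridOrder::Invalid;
}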
tests/unittest/ops/test_permute021.py (3 additions, 0 deletions)
@@ -69,6 +69,9 @@ def _test_permute_021(
param(3, (2, 3, 4, 384, 262), (0, 1, 2, 4, 3)),
param(4, (IntVar([2, 3]), 384, 262), (0, 2, 1)),
param(5, (IntVar([2, 3, 4]), 5, 384, 262), (0, 1, 3, 2)),
param(6, (409600, 12, 16), (0, 2, 1)),
param(7, (12, 409600, 16), (0, 2, 1)),
param(8, (12, 16, 409600), (0, 2, 1)),
]
)
def test_permute021_fp16(self, id, input_shape, dims):
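The three added shapes each make a different extent the large one; in particular, (409600, 12, 16) forces grid_n = 409600 > 65535 regardless of tile size, since the batch dimension is not tiled. A rough check of the resulting grid sizes, assuming an illustrative TILE_SIZE of 32 (the real value comes from the generated template):

#include <cstdio>

int main() {
  constexpr int kTile = 32; // assumed tile size, for illustration only
  // (n, hw, c) for the three added permute(0, 2, 1) test cases
  const int shapes[3][3] = {{409600, 12, 16}, {12, 409600, 16}, {12, 16, 409600}};
  for (const auto& s : shapes) {
    const int grid_n = s[0];                       // n is not tiled
    const int grid_hw = (s[1] + kTile - 1) / kTile;
    const int grid_c = (s[2] + kTile - 1) / kTile;
    printf("(%d, %d, %d) -> grid_n=%d grid_hw=%d grid_c=%d\n",
           s[0], s[1], s[2], grid_n, grid_hw, grid_c);
  }
  return 0;
}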
