From 518ee4684a3ae38d5e6d8f9ebca5e25ac1389530 Mon Sep 17 00:00:00 2001 From: Yi-Hong Lyu Date: Wed, 15 Jan 2025 21:49:46 -0800 Subject: [PATCH] Optimitze DepthToSpace for mode = DCR, blocksize = 4 and width = 8x --- .../providers/cpu/tensor/space_depth_ops.cc | 144 ++++++++++++++++-- 1 file changed, 132 insertions(+), 12 deletions(-) diff --git a/onnxruntime/core/providers/cpu/tensor/space_depth_ops.cc b/onnxruntime/core/providers/cpu/tensor/space_depth_ops.cc index 7e1049c402210..bfbcbf243ee24 100644 --- a/onnxruntime/core/providers/cpu/tensor/space_depth_ops.cc +++ b/onnxruntime/core/providers/cpu/tensor/space_depth_ops.cc @@ -142,6 +142,121 @@ Status SpaceToDepth::Compute(OpKernelContext* context) const { return Status::OK(); } +template +static Status DepthSpaceDCRBlock4Width8XOpCpuImpl(OpKernelContext* ctx, + const Tensor& input, Tensor& output, + const int64_t batch, + const int64_t input_depth, + const int64_t input_height, + const int64_t input_width) { + std::vector permutations = {0, 3, 4, 1, 5, 2}; + constexpr int blocksize = 4; + constexpr int internal_rank = 6; + int64_t internal_output_depth = input_depth / blocksize / blocksize; + const TensorShape internal_input_shape = TensorShape{batch, blocksize, blocksize, + internal_output_depth, input_height, input_width}; + const TensorShape internal_output_shape = TensorShape{batch, internal_output_depth, + input_height, blocksize, + input_width, blocksize}; + const int64_t number_of_elements = internal_input_shape.Size(); + const auto& internal_output_dims = internal_output_shape.GetDims(); + + InlinedVector stride(internal_rank); + for (size_t i = 0; i < internal_rank; i++) { + size_t inpdim = permutations[i]; + if (inpdim + 1 < internal_rank) + stride[i] = onnxruntime::narrow(internal_input_shape.SizeFromDimension(inpdim + 1)); + else + stride[i] = 1; + } + + InlinedVector internal_output_stride(internal_rank); + internal_output_stride[internal_rank - 1] = 1; + for (int64_t i = internal_rank - 2; i >= 0; --i) { + internal_output_stride[i] = internal_output_stride[i + 1] * internal_output_dims[i + 1]; + } + + const auto* input_data = reinterpret_cast(input.DataRaw()); + auto* output_data = reinterpret_cast(output.MutableDataRaw()); + + Status status = Status::OK(); + + concurrency::ThreadPool::TryParallelFor( + ctx->GetOperatorThreadPool(), static_cast(number_of_elements), + {static_cast(sizeof(uint8_t)), static_cast(sizeof(uint8_t)), 1.0F}, + [&internal_output_stride, input_data, &stride, output_data](std::ptrdiff_t first, + std::ptrdiff_t last) { + constexpr int chunk_size = 32; + + ORT_ENFORCE((first < last) && (first % chunk_size == 0) && (last % chunk_size == 0)); + + /// The loop is unrolled by 32 for the code below: + /// for (std::ptrdiff_t i = first; i < last; ++i) { + /// int d0 = static_cast(i / internal_output_stride[0]); + /// int d1 = static_cast((i % internal_output_stride[0]) / internal_output_stride[1]); + /// int d2 = static_cast((i % internal_output_stride[1]) / internal_output_stride[2]); + /// int d3 = static_cast((i % internal_output_stride[2]) / internal_output_stride[3]); + /// int d4 = static_cast((i % internal_output_stride[3]) / internal_output_stride[4] /* blocksize = 4 */); + /// int d5 = static_cast(i % internal_output_stride[4] /* blocksize = 4 */); + /// const T* source = input_data + (d0 * stride[0] + + /// d1 * stride[1] + + /// d2 * stride[2] + + /// d3 * stride[3] + + /// d4 * stride[4] /* 1 */ + + /// d5 * stride[5]); + /// T* target = output_data + i; + /// *target = *source; + /// } + for (std::ptrdiff_t i = first; i < last; i += chunk_size) { + int d0 = static_cast(i / internal_output_stride[0]); + int d1 = static_cast((i % internal_output_stride[0]) / internal_output_stride[1]); + int d2 = static_cast((i % internal_output_stride[1]) / internal_output_stride[2]); + int d3 = static_cast((i % internal_output_stride[2]) / internal_output_stride[3]); + int d4 = static_cast((i % internal_output_stride[3]) / 4 /* blocksize = internal_output_stride[4] */); + const T* source = input_data + (d0 * stride[0] + + d1 * stride[1] + + d2 * stride[2] + + d3 * stride[3] + + d4 * 1 /* stride[4] */); + T* target = output_data + i; + *(target + 0) = *(source + ((0 / 4) * 1) + ((0 % 4) * stride[5])); + *(target + 1) = *(source + ((1 / 4) * 1) + ((1 % 4) * stride[5])); + *(target + 2) = *(source + ((2 / 4) * 1) + ((2 % 4) * stride[5])); + *(target + 3) = *(source + ((3 / 4) * 1) + ((3 % 4) * stride[5])); + *(target + 4) = *(source + ((4 / 4) * 1) + ((4 % 4) * stride[5])); + *(target + 5) = *(source + ((5 / 4) * 1) + ((5 % 4) * stride[5])); + *(target + 6) = *(source + ((6 / 4) * 1) + ((6 % 4) * stride[5])); + *(target + 7) = *(source + ((7 / 4) * 1) + ((7 % 4) * stride[5])); + *(target + 8) = *(source + ((8 / 4) * 1) + ((8 % 4) * stride[5])); + *(target + 9) = *(source + ((9 / 4) * 1) + ((9 % 4) * stride[5])); + *(target + 10) = *(source + ((10 / 4) * 1) + ((10 % 4) * stride[5])); + *(target + 11) = *(source + ((11 / 4) * 1) + ((11 % 4) * stride[5])); + *(target + 12) = *(source + ((12 / 4) * 1) + ((12 % 4) * stride[5])); + *(target + 13) = *(source + ((13 / 4) * 1) + ((13 % 4) * stride[5])); + *(target + 14) = *(source + ((14 / 4) * 1) + ((14 % 4) * stride[5])); + *(target + 15) = *(source + ((15 / 4) * 1) + ((15 % 4) * stride[5])); + *(target + 16) = *(source + ((16 / 4) * 1) + ((16 % 4) * stride[5])); + *(target + 17) = *(source + ((17 / 4) * 1) + ((17 % 4) * stride[5])); + *(target + 18) = *(source + ((18 / 4) * 1) + ((18 % 4) * stride[5])); + *(target + 19) = *(source + ((19 / 4) * 1) + ((19 % 4) * stride[5])); + *(target + 20) = *(source + ((20 / 4) * 1) + ((20 % 4) * stride[5])); + *(target + 21) = *(source + ((21 / 4) * 1) + ((21 % 4) * stride[5])); + *(target + 22) = *(source + ((22 / 4) * 1) + ((22 % 4) * stride[5])); + *(target + 23) = *(source + ((23 / 4) * 1) + ((23 % 4) * stride[5])); + *(target + 24) = *(source + ((24 / 4) * 1) + ((24 % 4) * stride[5])); + *(target + 25) = *(source + ((25 / 4) * 1) + ((25 % 4) * stride[5])); + *(target + 26) = *(source + ((26 / 4) * 1) + ((26 % 4) * stride[5])); + *(target + 27) = *(source + ((27 / 4) * 1) + ((27 % 4) * stride[5])); + *(target + 28) = *(source + ((28 / 4) * 1) + ((28 % 4) * stride[5])); + *(target + 29) = *(source + ((29 / 4) * 1) + ((29 % 4) * stride[5])); + *(target + 30) = *(source + ((30 / 4) * 1) + ((30 % 4) * stride[5])); + *(target + 31) = *(source + ((31 / 4) * 1) + ((31 % 4) * stride[5])); + } + }); + + return status; +} + Status DepthToSpace::Compute(OpKernelContext* context) const { const auto* tensor_pointer = context->Input(0); if (tensor_pointer == nullptr) return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch"); @@ -199,18 +314,23 @@ Status DepthToSpace::Compute(OpKernelContext* context) const { onnxruntime::narrow(input_width), onnxruntime::narrow(blocksize_)); } else if (input.IsDataType()) { - SpaceDepthOpCpuImpl(input, output, permutation, - onnxruntime::narrow(batch), - onnxruntime::narrow(dim1), - onnxruntime::narrow(blocksize_), - onnxruntime::narrow(dim3), - onnxruntime::narrow(input_height), - onnxruntime::narrow(input_width), - onnxruntime::narrow(input_depth / blocksize_ / blocksize_), - onnxruntime::narrow(input_height), - onnxruntime::narrow(blocksize_), - onnxruntime::narrow(input_width), - onnxruntime::narrow(blocksize_)); + if (is_dcr_ && (blocksize_ == 4) && (input_width % 8 == 0)) { + ORT_RETURN_IF_ERROR(DepthSpaceDCRBlock4Width8XOpCpuImpl(context, input, output, + batch, input_depth, input_height, input_width)); + } else { + SpaceDepthOpCpuImpl(input, output, permutation, + onnxruntime::narrow(batch), + onnxruntime::narrow(dim1), + onnxruntime::narrow(blocksize_), + onnxruntime::narrow(dim3), + onnxruntime::narrow(input_height), + onnxruntime::narrow(input_width), + onnxruntime::narrow(input_depth / blocksize_ / blocksize_), + onnxruntime::narrow(input_height), + onnxruntime::narrow(blocksize_), + onnxruntime::narrow(input_width), + onnxruntime::narrow(blocksize_)); + } } else { // user will not see this as the kernel doesn't claim support for types other than float and double return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported input type in DepthToSpace op: ", input.DataType());