Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/main' into snnn/update_android
Browse files Browse the repository at this point in the history
  • Loading branch information
snnn committed Jan 17, 2025
2 parents aa081d5 + b8599b7 commit 79ce02f
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 1 deletion.
4 changes: 3 additions & 1 deletion onnxruntime/core/providers/cpu/nn/batch_norm.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,11 @@ class BatchNorm : public OpKernel {
const TensorShape& x_shape = X->Shape();
Tensor* Y = p_op_kernel_context->Output(0, x_shape);

// X shape is [N, C, D1, D2, ... Dn], but it can also be 1-D according to onnx spec:
// "The op also accepts single dimension input of size N in which case C is assumed to be 1"
const auto& dims_vec = x_shape.GetDims();
const size_t N = onnxruntime::narrow<size_t>(dims_vec[0]);
const size_t C = onnxruntime::narrow<size_t>(dims_vec[1]); // assume NCHW as per the spec
const size_t C = dims_vec.size() == 1 ? 1 : onnxruntime::narrow<size_t>(dims_vec[1]);

// calculate sample_size (per individual channel)
size_t sample_size = 1;
Expand Down
3 changes: 3 additions & 0 deletions onnxruntime/core/providers/cpu/nn/batch_norm_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ class BatchNormHelper {
// NHWC dependent shape: X
// All other shapes are assumed to be in NCHW layout?
const auto& x_dims = X->Shape().GetDims();
if (x_dims.size() < 1) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input X: NumDimensions() < 1");
}

// If x_dims size < 2, num_channels defaults to 1.
int64_t num_channels;
Expand Down

0 comments on commit 79ce02f

Please sign in to comment.