This appears to be NPU-specific: the same model works on CPU and GPU.
Tested shapes (hidden_channels x height x width):
shape: [4 - 396] x 23 x 40 - works
shape: [400 - ...] x 23 x 40 - fails
shape: [4 - 1024] x 24 x 40 - works
shape: [4 - ...] x 25 x 40 - works
shape: 1024 x [5, 7, 10, 11, 13-19, 21-23, 25-27, 29-31] x 40 - fails
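To re-run a sweep like the one above in one go, a small harness can record which (device, shape) pairs compile. This is only a sketch: `expand_ranges`, `sweep_shapes` and the `try_compile` callback are hypothetical names, and it assumes compile failures surface as `RuntimeError`, as they do in the traceback in this report.

```python
from typing import Callable, Dict, Iterable, List, Tuple

# (hidden_channels, height, width), the same order as the list above
Shape = Tuple[int, int, int]


def expand_ranges(spec: str) -> List[int]:
    """Expand a spec such as "5, 7, 13-19" into [5, 7, 13, 14, ..., 19]."""
    values: List[int] = []
    for part in spec.split(","):
        part = part.strip()
        if not part:
            continue
        if "-" in part:
            start, end = (int(p) for p in part.split("-", 1))
            values.extend(range(start, end + 1))  # ranges are inclusive
        else:
            values.append(int(part))
    return values


def sweep_shapes(
    try_compile: Callable[[int, int, int, str], None],
    shapes: Iterable[Shape],
    devices: Iterable[str] = ("NPU",),
) -> Dict[Tuple[str, Shape], str]:
    """Call try_compile(c, h, w, device) for every pair and record the outcome.

    Assumption: try_compile raises RuntimeError when compilation fails,
    which is how the OpenVINO Python API reported the NPU error here.
    """
    shape_list = list(shapes)  # accept generators; iterated once per device
    results: Dict[Tuple[str, Shape], str] = {}
    for device in devices:
        for shape in shape_list:
            try:
                try_compile(*shape, device)
                results[(device, shape)] = "works"
            except RuntimeError as exc:
                # OpenVINO error messages are multi-line; keep the first line.
                lines = str(exc).splitlines()
                results[(device, shape)] = "fails: " + (lines[0] if lines else "")
    return results
```

For the last row of the list, this would be `sweep_shapes(try_compile, [(1024, h, 40) for h in expand_ranges("5, 7, 10, 11, 13-19, 21-23, 25-27, 29-31")])`, where `try_compile` is a hypothetical variant of the reproducer's `test_model` that takes the device name as a fourth argument instead of hard-coding "NPU". Passing `devices=("CPU", "GPU", "NPU")` checks the CPU/GPU comparison in the same run.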
Step-by-step reproduction
The sample code below builds a small model with a convolution followed by a depth-to-space layer (nn.PixelShuffle, exported to ONNX as DepthToSpace), switches the DepthToSpace layer to blocks_first (DCR) mode, and then compiles and runs it on the NPU. Model compilation fails for some shapes, for example 1024 x 23 x 40.
import onnx
import torch
import openvino as ov
import torch.nn as nn


class MyModel(nn.Module):
    def __init__(self, hidden_channels=128):
        super(MyModel, self).__init__()
        self.conv = nn.Conv2d(8, hidden_channels, kernel_size=3, padding=1)
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor=2)

    def forward(self, x):
        return self.pixel_shuffle(self.conv(x))


def test_model(hidden_channels, height, width):
    model = MyModel(hidden_channels=hidden_channels).eval()

    # ONNX export
    dummy_input = torch.randn(1, 8, height, width)
    onnx_path = "model.onnx"
    torch.onnx.export(
        model,
        dummy_input,
        onnx_path,
        input_names=["input"],
        output_names=["output"],
        opset_version=11,
    )

    # Convert DepthToSpace from mode CRD to DCR
    # For testing only, this will change the output of the model
    m = onnx.load(onnx_path)
    for node in m.graph.node:
        if node.op_type == "DepthToSpace":
            for attr in node.attribute:
                if attr.name == "mode" and attr.s == b"CRD":
                    attr.s = b"DCR"
    onnx.save(m, onnx_path)

    # Export to OpenVINO
    ov_model = ov.convert_model(onnx_path)
    ov.save_model(ov_model, "model.xml")

    # Test OpenVINO model
    core = ov.Core()
    compiled = core.compile_model(core.read_model("model.xml"), "NPU")
    result = compiled([dummy_input.numpy()])[compiled.output(0)]
    print(f"Channels: {hidden_channels}, output shape: {result.shape}")


if __name__ == "__main__":
    test_model(1024, 23, 40)
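The CRD-to-DCR patch in the reproducer changes the element order of the DepthToSpace output but not its shape. The NumPy sketch below follows the reference formulas in the ONNX DepthToSpace specification; the function name `depth_to_space` is only for illustration. CRD is the layout `nn.PixelShuffle` uses, so the patched DCR model no longer matches the PyTorch output, as the comment in the reproducer says. A reference like this could be used to check NPU results for the shapes that do compile.

```python
import numpy as np


def depth_to_space(x: np.ndarray, blocksize: int, mode: str = "DCR") -> np.ndarray:
    """NCHW DepthToSpace following the ONNX reference formulas.

    For output y[b, k, h*bs + i, w*bs + j] the source channel is
      DCR: (i * bs + j) * out_c + k
      CRD: k * bs * bs + i * bs + j   (same layout as torch.nn.PixelShuffle)
    """
    b, c, h, w = x.shape
    if c % (blocksize * blocksize):
        raise ValueError("channels must be divisible by blocksize**2")
    out_c = c // (blocksize * blocksize)
    if mode == "DCR":
        # Block offsets are the outer channel factors, output channel the inner one.
        tmp = x.reshape(b, blocksize, blocksize, out_c, h, w)
        tmp = tmp.transpose(0, 3, 4, 1, 5, 2)
    elif mode == "CRD":
        # Output channel is the outer channel factor, block offsets the inner ones.
        tmp = x.reshape(b, out_c, blocksize, blocksize, h, w)
        tmp = tmp.transpose(0, 1, 4, 2, 5, 3)
    else:
        raise ValueError(f"unknown mode: {mode!r}")
    return tmp.reshape(b, out_c, h * blocksize, w * blocksize)
```

Both modes give the same output shape: for the reproducer's sizes, a 1 x 1024 x 23 x 40 input becomes 1 x 256 x 46 x 80. The two modes only coincide when there is a single output channel (C equal to blocksize squared).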
Relevant log output
[ERROR] 13:12:11.629 [vpux-compiler] Got Diagnostic at loc(fused<{name = "main", type = "Func"}>["main"]) : TilingStrategyAssignment Pass failed : Cannot get per cluster memory shapes. Shape [1, 1024, 1, 80], Unsupported distribution: #VPU.DistributionInfo<mode = <SEGMENTED>, num_tiles = [1, 1, 4, 1], num_clusters = 4 : i64, uniform_distributed_segments = unit>
loc(fused<{name = "main", type = "Func"}>["main"]): error: TilingStrategyAssignment Pass failed : Cannot get per cluster memory shapes. Shape [1, 1024, 1, 80], Unsupported distribution: #VPU.DistributionInfo<mode = <SEGMENTED>, num_tiles = [1, 1, 4, 1], num_clusters = 4 : i64, uniform_distributed_segments = unit>
[ERROR] 13:12:11.670 [vpux-compiler] Failed Pass TilingStrategyAssignment on Operation loc(fused<{name = "main", type = "Func"}>["main"])
Traceback (most recent call last):
  File "test_openvino.py", line 62, in <module>
    test_model(1024, 23, 80)
  File "test_openvino.py", line 48, in test_model
    compiled = core.compile_model(core.read_model("model.xml"), "NPU")
  File "site-packages\openvino\_ov_api.py", line 597, in compile_model
    super().compile_model(model, device_name, {} if config is None else config),
RuntimeError: Exception from src\inference\src\cpp\core.cpp:112:
Exception from src\inference\src\dev\plugin.cpp:53:
Exception from src\plugins\intel_npu\src\plugin\src\plugin.cpp:750:
Exception from src\plugins\intel_npu\src\compiler_adapter\src\ze_graph_ext_wrappers.cpp:361:
L0 pfnCreate2 result: ZE_RESULT_ERROR_INVALID_ARGUMENT, code 0x78000004 - generic error code for invalid arguments . Compilation failed
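The log asks for a SEGMENTED distribution with num_tiles = [1, 1, 4, 1] over 4 clusters on a tensor of shape [1, 1024, 1, 80], that is, splitting an axis of size 1 four ways. The sketch below does not reproduce the vpux compiler's logic; it assumes a near-even contiguous split (the same rule as `numpy.array_split`) purely to illustrate one plausible reason per-cluster shapes cannot be formed: three of the four clusters would receive an empty slice. `segment_sizes` and `has_empty_segment` are hypothetical helpers.

```python
from typing import List, Sequence


def segment_sizes(dim: int, num_tiles: int) -> List[int]:
    """Split `dim` elements into `num_tiles` contiguous, near-equal segments.

    Assumption: the first dim % num_tiles segments get one extra element,
    as in numpy.array_split. This rule is not taken from the NPU compiler.
    """
    base, extra = divmod(dim, num_tiles)
    return [base + 1 if i < extra else base for i in range(num_tiles)]


def has_empty_segment(shape: Sequence[int], num_tiles: Sequence[int]) -> bool:
    """True if splitting any tiled axis leaves at least one zero-sized tile."""
    return any(
        0 in segment_sizes(dim, tiles)
        for dim, tiles in zip(shape, num_tiles)
        if tiles > 1
    )
```

With the values from the log, `has_empty_segment([1, 1024, 1, 80], [1, 1, 4, 1])` returns True, while the same split of a 46-row axis gives segments of 12, 12, 11 and 11 rows.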
Issue submission checklist
I'm reporting an issue. It's not a question.
I checked the problem with the documentation, FAQ, open issues, Stack Overflow, etc., and have not found a solution.
There is reproducer code and related data files such as images, videos, models, etc.
OpenVINO Version
2025.0.0
Operating System
Windows System (NPU driver: 32.0.100.3714)
Device used for inference
NPU (Core Ultra 9 288V)
Framework
None
Model used
No response
Issue description
When DepthToSpace is in blocks_first mode, some input shapes cause it to fail on the NPU with the following error:
[ERROR] 13:15:50.502 [vpux-compiler] Got Diagnostic at loc(fused<{name = "main", type = "Func"}>["main"]) : TilingStrategyAssignment Pass failed : Cannot get per cluster memory shapes. Shape [1, 1024, 1, 80], Unsupported distribution: #VPU.DistributionInfo<mode = <SEGMENTED>, num_tiles = [1, 1, 4, 1], num_clusters = 4 : i64, uniform_distributed_segments = unit>