I am trying to convert an open-clip (pip install open_clip_torch==2.30.0) model to TensorRT. Exporting the model to ONNX produces a valid ONNX file, such that ONNX Runtime execution matches the original PyTorch model. I then build a TensorRT engine from that ONNX file with trtexec and compare the engine outputs against the ONNX Runtime reference.
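The original export script is not shown here, so the following is only a minimal sketch of what that step could look like; the pretrained tag ("webli"), the wrapper module, the 384x384 input size, and the opset are my assumptions, chosen so the exported node names carry the /visual/trunk/... prefixes seen in the trtexec logs below.

```python
# Hypothetical export sketch (not the exact script from the issue).
# Assumptions: "webli" weights, 384x384 input, opset 17, and a wrapper module
# that holds the image tower under the attribute name "visual" so exported
# node names get the /visual/... prefix.
import torch
import open_clip


class VisualOnly(torch.nn.Module):
    def __init__(self, clip_model):
        super().__init__()
        self.visual = clip_model.visual  # image tower only

    def forward(self, image):
        return self.visual(image)


model, _, _ = open_clip.create_model_and_transforms(
    "ViT-SO400M-14-SigLIP-384", pretrained="webli"
)
wrapper = VisualOnly(model).eval()

dummy = torch.randn(1, 3, 384, 384)
torch.onnx.export(
    wrapper,
    dummy,
    "ViT-SO400M-14-SigLIP-384.onnx",
    input_names=["image"],
    output_names=["image_embedding"],
    opset_version=17,
    dynamic_axes={"image": {0: "batch"}},
)
```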
Note the magnitude of the relative error: p90=6.266 (!). This happens on my RTX A4500 Laptop GPU (driver 560) and on my V100 (though there I use the tensorrt:24.06-py3 container, since TensorRT 10.7 no longer supports Volta). The FP16/BF16 case is even worse.
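The exact comparison script is not part of the issue; the statistic quoted above is presumably the 90th percentile of the element-wise relative error between a reference output and the TensorRT engine output, along the lines of:

```python
# Hypothetical comparison helper: "p90" here means the 90th percentile of the
# element-wise relative error between a reference output (e.g. ONNX Runtime)
# and a candidate output (the TensorRT engine).
import numpy as np


def relative_error_stats(reference, candidate, eps=1e-6):
    rel = np.abs(candidate - reference) / (np.abs(reference) + eps)
    return {
        "mean": float(rel.mean()),
        "p90": float(np.percentile(rel, 90)),
        "max": float(rel.max()),
    }


# reference = output from ONNX Runtime, candidate = output from the TensorRT engine
# print(relative_error_stats(reference, candidate))
```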
When I do the same conversion with --fp8, the error vanishes (note that the A4500 and V100 do not support FP8 kernels). I compared the trtexec verbose logs and found that in the FP32 case TensorRT recognizes the self-attention pattern, but in the FP8 case it does not:
trtexec --onnx=ViT-SO400M-14-SigLIP-384.onnx --verbose
[...]
[01/18/2025-15:16:59] [V] [TRT] Found /visual/trunk/blocks/blocks.18/attn/MatMul to be part of self-attention pattern.
[01/18/2025-15:16:59] [V] [TRT] Found /visual/trunk/blocks/blocks.18/attn/Softmax to be part of self-attention pattern.
[01/18/2025-15:16:59] [V] [TRT] Found /visual/trunk/blocks/blocks.18/attn/MatMul_1 to be part of self-attention pattern.
[01/18/2025-15:16:59] [V] [TRT] Found and reassigned Myelin backends for Self-Attention nodes
[...]
This observation got me thinking: when I replace the /attn/Softmax nodes with a custom TensorRT softmax plugin, the TensorRT optimizer can no longer apply the self-attention fusion, and the resulting engines have acceptable accuracy (even in FP16).
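For reference, a minimal sketch of the graph edit this involves; CustomSoftmaxPlugin is a hypothetical op name under which a TensorRT softmax plugin would have to be registered, and the plugin implementation itself is not shown:

```python
# Hypothetical sketch: retarget the attention Softmax nodes to a custom op so the
# ONNX parser maps them to a plugin instead of letting Myelin fuse them into its
# self-attention kernel. A softmax plugin registered under this op name must be
# loaded when building the engine.
import onnx
import onnx_graphsurgeon as gs

graph = gs.import_onnx(onnx.load("ViT-SO400M-14-SigLIP-384.onnx"))

for node in graph.nodes:
    if node.op == "Softmax" and "/attn/" in node.name:
        node.op = "CustomSoftmaxPlugin"  # hypothetical plugin op name

onnx.save(gs.export_onnx(graph), "ViT-SO400M-14-SigLIP-384.patched.onnx")
```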
My conclusion: somehow, for this model, the Myelin self-attention fusion is buggy.