Engine build failure in TensorRT 10.7 when building a quantized model on an NVIDIA GeForce RTX 3090 GPU #4320

Open
DuanHongxuan opened this issue Jan 14, 2025 · 3 comments


DuanHongxuan commented Jan 14, 2025

Description

I tried to quantize an FP32 ONNX model into an INT8 TensorRT engine.
When I use TensorRT's Python API to convert this ONNX model to a TensorRT engine, I get the following error:

[01/14/2025-16:56:50] [TRT] [I] [MemUsageChange] Init CUDA: CPU +2, GPU +0, now: CPU 115, GPU 1322 (MiB)
[01/14/2025-16:56:51] [TRT] [I] [MemUsageChange] Init builder kernel library: CPU +2073, GPU +385, now: CPU 2343, GPU 1707 (MiB)
Parsing ONNX file: ./eloftr_outdoor_ds1_fp32_640x640_640x640_v17.onnx
Building an engine from file ./eloftr_outdoor_ds1_fp32_640x640_640x640_v17.onnx, this may take a while...
Loaded 300 images for calibration
Building INT8 engine...
[01/14/2025-16:56:52] [TRT] [W] /model/fine_matching/Reshape_12: IShuffleLayer with zeroIsPlaceHolder=true has reshape dimension at position 1 that might or might not be zero. TensorRT resolves it at runtime, but this may cause excessive memory consumption and is usually a sign of a bug in the network.
[01/14/2025-16:56:52] [TRT] [W] /model/fine_matching/Reshape_12: IShuffleLayer with zeroIsPlaceHolder=true has reshape dimension at position 1 that might or might not be zero. TensorRT resolves it at runtime, but this may cause excessive memory consumption and is usually a sign of a bug in the network.
[01/14/2025-16:56:52] [TRT] [I] Perform graph optimization on calibration graph.
[01/14/2025-16:56:52] [TRT] [I] Local timing cache in use. Profiling results in this builder pass will not be stored.
[01/14/2025-16:56:54] [TRT] [I] [GraphReduction] The approximate region cut reduction algorithm is called.
[01/14/2025-16:56:54] [TRT] [I] Detected 2 inputs and 3 output network tensors.
[01/14/2025-16:56:57] [TRT] [E] IBuilder::buildSerializedNetwork: Error Code 4: Shape Error (broadcast dimensions must be conformable)
Failed to build serialized engine
Failed to build engine
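The two IShuffleLayer warnings in the log refer to Reshape's placeholder semantics. Under the ONNX default (allowzero=0), a 0 in the target shape means "copy the input dimension at this position", and TensorRT's zeroIsPlaceHolder flag appears to mirror that. If the target value at position 1 is only computed at runtime, the builder cannot tell in advance whether it is a real size or a placeholder. The following plain-Python sketch shows how the ONNX spec resolves 0 and -1 for a concrete input shape; onnx_reshape_shape is a hypothetical helper, not a TensorRT or ONNX API.

```python
from math import prod


def onnx_reshape_shape(input_shape, target, allowzero=False):
    """Resolve an ONNX Reshape target shape against a concrete input shape.

    With allowzero=False (the ONNX default), a 0 in `target` is a placeholder
    that copies the input dimension at the same position; a single -1 is
    inferred so that the total element count is preserved.
    """
    out = []
    for i, d in enumerate(target):
        if d == 0 and not allowzero:
            if i >= len(input_shape):
                raise ValueError(f"placeholder 0 at position {i} has no input dimension")
            out.append(input_shape[i])  # placeholder: copy the input dim
        else:
            out.append(d)  # literal size, or -1 (resolved below)
    if out.count(-1) > 1:
        raise ValueError("at most one -1 is allowed")
    total = prod(input_shape)
    if -1 in out:
        known = prod(d for d in out if d != -1)
        if known == 0 or total % known != 0:
            raise ValueError(f"cannot infer -1 for {tuple(input_shape)} -> {tuple(target)}")
        out[out.index(-1)] = total // known
    if prod(out) != total:
        raise ValueError(f"element count mismatch: {tuple(input_shape)} -> {tuple(out)}")
    return tuple(out)


print(onnx_reshape_shape((1, 5, 64), (1, 0, 64)))  # (1, 5, 64): 0 copies dim 1
```

With allowzero=True the same target (1, 0, 64) would mean a literal zero-sized dimension, and the call raises because the element counts no longer match. That ambiguity is why a runtime-computed 0 is worth avoiding in the exported graph.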

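The build itself fails with "Shape Error (broadcast dimensions must be conformable)". That wording presumably refers to the multidirectional broadcasting rule used by ONNX element-wise operators, which is the same rule NumPy uses: shapes are right-aligned, and each pair of dimensions must be equal or contain a 1. A minimal sketch of that rule follows. It is illustrative only and does not identify which layer in this model triggers the error; broadcast_shape is a hypothetical helper.

```python
def broadcast_shape(a, b):
    """Return the NumPy/ONNX-style broadcast of shapes a and b.

    Shapes are right-aligned; each pair of dimensions must be equal or one
    of them must be 1, otherwise the shapes are not conformable.
    """
    rank = max(len(a), len(b))
    # Left-pad the shorter shape with 1s so both have the same rank.
    a = (1,) * (rank - len(a)) + tuple(a)
    b = (1,) * (rank - len(b)) + tuple(b)
    out = []
    for i, (x, y) in enumerate(zip(a, b)):
        if x == y or y == 1:
            out.append(x)
        elif x == 1:
            out.append(y)
        else:
            raise ValueError(f"dimension {i}: {x} and {y} are not conformable")
    return tuple(out)


print(broadcast_shape((8, 1, 64), (1, 4, 64)))  # (8, 4, 64)
```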
Environment

TensorRT Version: 10.7

NVIDIA GPU: NVIDIA GeForce RTX 3090

NVIDIA Driver Version: 535.183.01

CUDA Version: 12.1

CUDNN Version: 8.9.5

Operating System: Ubuntu 20.04.6 LTS

Python Version (if applicable): 3.10

Tensorflow Version (if applicable):

PyTorch Version (if applicable): 2.5.1

Baremetal or Container (if so, version):

Relevant Files

Model link:
EfficientLoFTR: https://drive.google.com/drive/folders/1nw1nhtInBfo65ux2I-GtaBXyTPLqtqnH
FP32 ONNX model: https://drive.google.com/file/d/1jKXhOtj5-LfQqrRW0R4R7j3hcStGWTGW/view?usp=drive_link

Steps To Reproduce

The conversion script (onnx2trt.py) I use to convert the ONNX model to a TensorRT engine:

import os
import onnx
import tensorrt as trt
import numpy as np
import cv2
from tqdm import tqdm
import pycuda.driver as cuda
import pycuda.autoinit
import torch
from onnxsim import simplify

print(f"TensorRT Version: {trt.__version__}")

class CalibrationDataLoader:
    def __init__(self, imgpath, input_shape=(1, 1, 640, 640), batch_size=2):
        """
        Args:
            imgpath (str): directory containing calibration images
            input_shape (tuple): shape of each network input (N, C, H, W)
            batch_size (int): images consumed per calibration step
                (one per network input: image0 and image1)
        """
        self.root_dir = imgpath
        self.input_shape = input_shape
        self.batch_size = batch_size
        self.batch_idx = 0

        # Get all image files
        self.imgs = [os.path.join(imgpath, file) for file in os.listdir(imgpath) if file.endswith('jpg')]
        np.random.shuffle(self.imgs)
        self.imgs = self.imgs[:300]

        # Calculate maximum batch index
        self.max_batch_idx = len(self.imgs) // self.batch_size

        input_size = np.dtype(np.float32).itemsize * np.prod(input_shape)
        input_size = int(input_size)

        # One device buffer per network input
        self.device_input0 = cuda.mem_alloc(input_size)
        self.device_input1 = cuda.mem_alloc(input_size)

        print(f"Loaded {len(self.imgs)} images for calibration")

    def preprocess_image(self, img_path):
        """Process single image following validation preprocessing"""
        try:
            # Read image in grayscale
            img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
            if img is None:
                raise ValueError(f"Failed to load image: {img_path}")

            # Resize to fixed 640x640
            img = cv2.resize(img, (640, 640))

            # Normalize to [0,1] and add batch/channel dimensions
            img = img.astype(np.float32) / 255.0
            img = img[None][None]  # Add batch and channel dimensions

            return img

        except Exception as e:
            print(f"Error processing image {img_path}: {str(e)}")
            return np.zeros(self.input_shape, dtype=np.float32)

    def get_batch(self, names, p_str=None):
        """Get next pair of images for calibration"""
        try:
            if self.batch_idx >= self.max_batch_idx:
                return None

            # Get pair of image paths (one per network input)
            img_path0 = self.imgs[self.batch_idx * 2]
            img_path1 = self.imgs[self.batch_idx * 2 + 1]

            # Process both images
            img0 = self.preprocess_image(img_path0)
            img1 = self.preprocess_image(img_path1)

            # Copy to GPU
            cuda.memcpy_htod(self.device_input0, img0)
            cuda.memcpy_htod(self.device_input1, img1)

            self.batch_idx += 1
            print(f"Batch: [{self.batch_idx}/{self.max_batch_idx}]")
            # Device pointers; assumes `names` lists image0 before image1
            return [int(self.device_input0), int(self.device_input1)]

        except Exception as e:
            print(f"Error in get_batch: {str(e)}")
            return None

class INT8Calibrator(trt.IInt8EntropyCalibrator):
    def __init__(self, dataloader):
        trt.IInt8EntropyCalibrator.__init__(self)
        self.dataloader = dataloader
        self.cache_file = "calibration.cache"

    def get_batch_size(self):
        return self.dataloader.batch_size

    def get_batch(self, names, p_str=None):
        return self.dataloader.get_batch(names, p_str)

    def read_calibration_cache(self):
        if os.path.exists(self.cache_file):
            print(f"Loading calibration cache: {self.cache_file}")
            with open(self.cache_file, "rb") as f:
                return f.read()
        # No cache: returning None makes TensorRT run calibration
        return None

    def write_calibration_cache(self, cache):
        print(f"Writing calibration cache: {self.cache_file}")
        with open(self.cache_file, "wb") as f:
            f.write(cache)
            f.flush()

def build_int8_engine(onnx_path, imgpath):
    """Build TensorRT engine with INT8 quantization"""
    logger = trt.Logger(trt.Logger.INFO)
    builder = trt.Builder(logger)
    network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    config = builder.create_builder_config()
    parser = trt.OnnxParser(network, logger)

    # Parse ONNX
    print(f"Parsing ONNX file: {onnx_path}")
    with open(onnx_path, 'rb') as model:
        if not parser.parse(model.read()):
            print('Failed to parse ONNX')
            for error in range(parser.num_errors):
                print(parser.get_error(error))
            return None

    print("Building an engine from file {}, this may take a while...".format(onnx_path))

    # Fixed 1x1x640x640 shape for every network input
    profile = builder.create_optimization_profile()
    for i in range(network.num_inputs):
        inp = network.get_input(i)
        profile.set_shape(
            inp.name,
            min=[1, 1, 640, 640],
            opt=[1, 1, 640, 640],
            max=[1, 1, 640, 640]
        )
    config.add_optimization_profile(profile)

    # config.add_optimization_profile(profile)
    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 8 << 30)
    config.set_flag(trt.BuilderFlag.INT8)
    # config.set_flag(trt.BuilderFlag.FP32)

    if not os.path.exists(imgpath):
        print(f"Error: Calibration image directory {imgpath} does not exist!")
        return None

    # Create calibrator
    dataloader = CalibrationDataLoader(imgpath)
    if len(dataloader.imgs) == 0:
        print(f"Error: No calibration images found in {imgpath}")
        return None

    config.int8_calibrator = INT8Calibrator(dataloader)

    # Build engine
    print('Building INT8 engine...')
    try:
        serialized_engine = builder.build_serialized_network(network, config)
        if serialized_engine is None:
            print("Failed to build serialized engine")
            return None

        engine_path = onnx_path.replace('.onnx', '_int8.trt')
        print(f"Saving engine to: {engine_path}")
        with open(engine_path, "wb") as f:
            f.write(serialized_engine)

        runtime = trt.Runtime(logger)
        engine = runtime.deserialize_cuda_engine(serialized_engine)
        return engine

    except Exception as e:
        print(f"Error building engine: {str(e)}")
        return None

def main():
    # Configuration
    imgpath = "./data/calibration_images/0015/images"
    onnx_path = "./eloftr_outdoor_ds1_fp32_640x640_640x640_v17.onnx"

    if not os.path.exists(onnx_path):
        print(f"Error: ONNX file {onnx_path} does not exist!")
        return

    if not os.path.exists(imgpath):
        print(f"Error: Calibration image directory {imgpath} does not exist!")
        return

    try:
        engine = build_int8_engine(onnx_path, imgpath)
        if engine is None:
            print("Failed to build engine")
            return
        print("INT8 Engine built successfully")
    except Exception as e:
        print(f"Error building engine: {str(e)}")


if __name__ == "__main__":
    main()
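In the script above, batch_size=2 is not a batch dimension: the loader consumes two images per get_batch() call, one for each network input, and each device buffer holds a single 1x1x640x640 float32 tensor. The resulting sizes can be checked with a small stdlib sketch; the function names are hypothetical and only mirror the script's own arithmetic.

```python
from math import prod

FLOAT32_BYTES = 4  # same as np.dtype(np.float32).itemsize


def calib_buffer_bytes(input_shape):
    """Bytes for one device buffer holding a float32 tensor of input_shape."""
    return FLOAT32_BYTES * prod(input_shape)


def calib_steps(num_images, images_per_step=2):
    """Number of get_batch() calls before the loader returns None.

    Mirrors max_batch_idx = len(imgs) // batch_size in the script.
    """
    return num_images // images_per_step


def pair_paths(paths):
    """Yield (image0, image1) pairs in the order get_batch() consumes them."""
    for step in range(calib_steps(len(paths))):
        yield paths[2 * step], paths[2 * step + 1]


print(calib_buffer_bytes((1, 1, 640, 640)))  # 1638400
print(calib_steps(300))  # 150
```

For the 300 images loaded here this gives 150 calibration steps and 4 x 640 x 640 = 1,638,400 bytes per input buffer. Note that get_batch_size() in the script returns the same value of 2 while each input tensor has a batch dimension of 1; how that value is interpreted for an explicit-batch network is worth confirming against the TensorRT documentation for the version in use.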

Commands or scripts: python onnx2trt.py

Have you tried the latest release?: Yes. I have also tried other TensorRT versions, but each version reports a different error.

Can this model run on other frameworks? For example run ONNX model with ONNXRuntime (polygraphy run <model.onnx> --onnxrt): yes

@lix19937

Try to use trtexec --int8 --onnx=spec --verbose to check.


DuanHongxuan commented Jan 15, 2025

> Try to use trtexec --int8 --onnx=spec --verbose to check.

Thanks, but it doesn't seem to solve my problem. The verbose build ends in a segmentation fault:

.........
[V] [TRT] Running: QDQToCopy on /model/loftr_coarse/layers.5/aggregate/_input_quantizer/QuantizeLinear
[V] [TRT] Swap the layer type of /model/loftr_coarse/layers.5/aggregate/_input_quantizer/QuantizeLinear from QUANTIZE to kQDQ
[V] [TRT] Running: QDQToCopy on /model/loftr_coarse/layers.5/aggregate/_input_quantizer_1/QuantizeLinear
[V] [TRT] Swap the layer type of /model/loftr_coarse/layers.5/aggregate/_input_quantizer_1/QuantizeLinear from QUANTIZE to kQDQ
[V] [TRT] Running: QDQToCopy on /model/loftr_coarse/layers.6/aggregate/_input_quantizer/QuantizeLinear
[V] [TRT] Swap the layer type of /model/loftr_coarse/layers.6/aggregate/_input_quantizer/QuantizeLinear from QUANTIZE to kQDQ
[V] [TRT] Running: QDQToCopy on /model/loftr_coarse/layers.6/aggregate/_input_quantizer_1/QuantizeLinear
[V] [TRT] Swap the layer type of /model/loftr_coarse/layers.6/aggregate/_input_quantizer_1/QuantizeLinear from QUANTIZE to kQDQ
[V] [TRT] Running: QDQToCopy on /model/loftr_coarse/layers.7/aggregate/_input_quantizer/QuantizeLinear
[01/10/2025-08:52:04] [V] [TRT] Swap the layer type of /model/loftr_coarse/layers.7/aggregate/_input_quantizer/QuantizeLinear from QUANTIZE to kQDQ
[V] [TRT] Running: QDQToCopy on /model/loftr_coarse/layers.7/aggregate/_input_quantizer_1/QuantizeLinear
[V] [TRT] Swap the layer type of /model/loftr_coarse/layers.7/aggregate/_input_quantizer_1/QuantizeLinear from QUANTIZE to kQDQ
[V] [TRT] Running: QDQToCopy on /model/fine_preprocess/layer3_outconv/_input_quantizer/QuantizeLinear
[V] [TRT] Swap the layer type of /model/fine_preprocess/layer3_outconv/_input_quantizer/QuantizeLinear from QUANTIZE to kQDQ
[V] [TRT] Running: QDQToCopy on /model/fine_preprocess/layer2_outconv2/layer2_outconv2.0/_input_quantizer/QuantizeLinear
[V] [TRT] Swap the layer type of /model/fine_preprocess/layer2_outconv2/layer2_outconv2.0/_input_quantizer/QuantizeLinear from QUANTIZE to kQDQ
[V] [TRT] Running: QDQToCopy on /model/fine_preprocess/layer1_outconv2/layer1_outconv2.0/_input_quantizer/QuantizeLinear
[V] [TRT] Swap the layer type of /model/fine_preprocess/layer1_outconv2/layer1_outconv2.0/_input_quantizer/QuantizeLinear from QUANTIZE to kQDQ
[V] [TRT] Running: ConstantSplit on model.loftr_coarse.layers.0.aggregate.weight + /model/loftr_coarse/layers.0/aggregate/_weight_quantizer/QuantizeLinear
[V] [TRT] Running: ConstantSplit on model.loftr_coarse.layers.1.aggregate.weight + /model/loftr_coarse/layers.1/aggregate/_weight_quantizer/QuantizeLinear
[V] [TRT] Running: ConstantSplit on model.loftr_coarse.layers.2.aggregate.weight + /model/loftr_coarse/layers.2/aggregate/_weight_quantizer/QuantizeLinear
[V] [TRT] Running: ConstantSplit on model.loftr_coarse.layers.3.aggregate.weight + /model/loftr_coarse/layers.3/aggregate/_weight_quantizer/QuantizeLinear
[V] [TRT] Running: ConstantSplit on model.loftr_coarse.layers.4.aggregate.weight + /model/loftr_coarse/layers.4/aggregate/_weight_quantizer/QuantizeLinear
[V] [TRT] Running: ConstantSplit on model.loftr_coarse.layers.5.aggregate.weight + /model/loftr_coarse/layers.5/aggregate/_weight_quantizer/QuantizeLinear
[V] [TRT] Running: ConstantSplit on model.loftr_coarse.layers.6.aggregate.weight + /model/loftr_coarse/layers.6/aggregate/_weight_quantizer/QuantizeLinear
[V] [TRT] Running: ConstantSplit on model.loftr_coarse.layers.7.aggregate.weight + /model/loftr_coarse/layers.7/aggregate/_weight_quantizer/QuantizeLinear
[V] [TRT] After dupe layer removal: 146 layers
[V] [TRT] After final dead-layer removal: 146 layers
[V] [TRT] After tensor merging: 146 layers
[V] [TRT] After vertical fusions: 146 layers
[V] [TRT] After dupe layer removal: 146 layers
[V] [TRT] After final dead-layer removal: 146 layers
[V] [TRT] After tensor merging: 146 layers
[V] [TRT] Replacing slice /model/Split with copy from /model/backbone/layer3.13/nonlinearity/Relu_output_0 to /model/Split_output_0
[V] [TRT] Replacing slice /model/Split_91 with copy from /model/backbone/layer3.13/nonlinearity/Relu_output_0 to /model/Split_output_1
[V] [TRT] After slice removal: 146 layers
[V] [TRT] Eliminating concatenation /model/Concat
[V] [TRT] Retargeting /model/Concat_image0_clone_0 to /model/backbone/layer0/rbr_reparam/_input_quantizer/QuantizeLinear_output_0
[V] [TRT] Retargeting /model/Concat_image1_clone_1 to /model/backbone/layer0/rbr_reparam/_input_quantizer/QuantizeLinear_output_0
[V] [TRT] After concat removal: 145 layers
[V] [TRT] Trying to split Reshape and strided tensor
[V] [TRT] Graph optimization time: 0.0440148 seconds.
[V] [TRT] Building graph using backend strategy 2
[I] [TRT] Local timing cache in use. Profiling results in this builder pass will not be stored.
[V] [TRT] Constructing optimization profile number 0 [1/1].
[V] [TRT] Applying generic optimizations to the graph for inference.
Segmentation fault (core dumped)
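The verbose build above ends with "Segmentation fault (core dumped)" instead of a TensorRT error message, so the process crashed rather than returning a failed build. When such checks are scripted, the two cases can be told apart from the exit status: on POSIX systems, Python's subprocess reports a child killed by signal N as returncode -N. A hedged sketch, assuming a POSIX system; run_build is a hypothetical helper and the command passed to it is up to the caller.

```python
import signal
import subprocess


def run_build(cmd):
    """Run a build command and classify how it ended.

    On POSIX, a child killed by signal N has returncode == -N, so a
    segmentation fault shows up as -SIGSEGV.
    """
    proc = subprocess.run(cmd, capture_output=True, text=True)
    if proc.returncode == 0:
        status = "ok"
    elif proc.returncode == -signal.SIGSEGV:
        status = "segfault"
    elif proc.returncode < 0:
        status = f"killed by signal {-proc.returncode}"
    else:
        status = f"failed with exit code {proc.returncode}"
    # stdout and stderr are concatenated, so their relative order is lost;
    # the last lines before a crash are usually the most informative.
    tail = (proc.stdout + proc.stderr).splitlines()[-20:]
    return status, tail
```

For example, run_build(["trtexec", "--onnx=model.onnx", "--int8", "--verbose"]), with model.onnx as a placeholder path, should report "segfault" for a crash like the one above, together with the last log lines printed before it.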

@lix19937

> Segmentation fault (core dumped)

1. Use a small ONNX model such as resnet50.onnx with trtexec to check the hardware environment.

BTW, since your model is a quantized model, make sure it was quantized with NVIDIA's QAT tool.
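For context on the QuantizeLinear/DequantizeLinear nodes in the verbose log: the ONNX operators compute saturate(round(x / scale) + zero_point) and (q - zero_point) * scale, with ties rounded to even. A QAT-exported graph places such Q/DQ pairs on activations and weights, and TensorRT uses them to decide where INT8 is applied. A scalar sketch of the int8 case, assuming symmetric quantization (zero_point = 0); the function names are illustrative, not a library API.

```python
def quantize_linear(x, scale, zero_point=0, qmin=-128, qmax=127):
    """ONNX QuantizeLinear for int8: saturate(round(x / scale) + zero_point).

    Python's round() rounds half to even, matching the ONNX spec.
    """
    q = round(x / scale) + zero_point
    return max(qmin, min(qmax, q))


def dequantize_linear(q, scale, zero_point=0):
    """ONNX DequantizeLinear: (q - zero_point) * scale."""
    return (q - zero_point) * scale


def fake_quant(x, scale):
    """Q followed by DQ, as in a QAT-exported graph (symmetric int8)."""
    return dequantize_linear(quantize_linear(x, scale), scale)


print(quantize_linear(1.25, 0.5))  # 2 (2.5 rounds half to even)
print(fake_quant(1.3, 0.5))  # 1.5
```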

Try to use trtexec --best --onnx=spec --verbose to check.
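The checks suggested in this thread can be collected into one list of trtexec invocations, run in order from the cheapest sanity check to the full builds. A minimal sketch; model_path stands in for the "spec" placeholder above, and trtexec_checks is a hypothetical helper. Only the --onnx, --int8, --best and --verbose flags mentioned in the thread are used.

```python
import shlex


def trtexec_checks(model_path, sanity_model="resnet50.onnx"):
    """Return the trtexec commands suggested in this thread, in order."""
    return [
        # 1. Sanity-check the TensorRT install and GPU with a small model.
        ["trtexec", f"--onnx={sanity_model}"],
        # 2. Build the quantized model with INT8 enabled and verbose logs.
        ["trtexec", "--int8", f"--onnx={model_path}", "--verbose"],
        # 3. Let TensorRT choose among all precisions (--best).
        ["trtexec", "--best", f"--onnx={model_path}", "--verbose"],
    ]


for cmd in trtexec_checks("model.onnx"):
    print(shlex.join(cmd))
```

If step 1 already fails, the problem is likely in the installation or driver rather than in the model.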
