Skip to content

Commit 57fadd4

Browse files
authored
Fix TensorRT potential unordered binding addresses (ultralytics#5826)
* feat: change file suffix in pythonic way * fix: enforce binding addresses order * fix: enforce binding addresses order
1 parent 777d5ba commit 57fadd4

File tree

2 files changed

+5
-4
lines changed

2 files changed

+5
-4
lines changed

export.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,7 @@ def export_engine(model, im, file, train, half, simplify, workspace=4, verbose=F
276276
assert onnx.exists(), f'failed to export ONNX file: {onnx}'
277277

278278
LOGGER.info(f'\n{prefix} starting export with TensorRT {trt.__version__}...')
279-
f = str(file).replace('.pt', '.engine') # TensorRT engine file
279+
f = file.with_suffix('.engine') # TensorRT engine file
280280
logger = trt.Logger(trt.Logger.INFO)
281281
if verbose:
282282
logger.min_severity = trt.Logger.Severity.VERBOSE
@@ -310,6 +310,7 @@ def export_engine(model, im, file, train, half, simplify, workspace=4, verbose=F
310310
except Exception as e:
311311
LOGGER.info(f'\n{prefix} export failure: {e}')
312312

313+
313314
@torch.no_grad()
314315
def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path'
315316
weights=ROOT / 'yolov5s.pt', # weights path

models/common.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import math
88
import platform
99
import warnings
10-
from collections import namedtuple
10+
from collections import OrderedDict, namedtuple
1111
from copy import copy
1212
from pathlib import Path
1313

@@ -326,14 +326,14 @@ def __init__(self, weights='yolov5s.pt', device=None, dnn=True):
326326
logger = trt.Logger(trt.Logger.INFO)
327327
with open(w, 'rb') as f, trt.Runtime(logger) as runtime:
328328
model = runtime.deserialize_cuda_engine(f.read())
329-
bindings = dict()
329+
bindings = OrderedDict()
330330
for index in range(model.num_bindings):
331331
name = model.get_binding_name(index)
332332
dtype = trt.nptype(model.get_binding_dtype(index))
333333
shape = tuple(model.get_binding_shape(index))
334334
data = torch.from_numpy(np.empty(shape, dtype=np.dtype(dtype))).to(device)
335335
bindings[name] = Binding(name, dtype, shape, data, int(data.data_ptr()))
336-
binding_addrs = {n: d.ptr for n, d in bindings.items()}
336+
binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items())
337337
context = model.create_execution_context()
338338
batch_size = bindings['images'].shape[0]
339339
else: # TensorFlow model (TFLite, pb, saved_model)

0 commit comments

Comments
 (0)