Skip to content

Commit 0a1a89d

Browse files
zldrobitpre-commit-ci[bot]glenn-jocher
authored
Fix TF exports >= 2GB (ultralytics#6292)
* Fix exporting saved_model: pb exceeds 2GB * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Replace TF v1.x API with TF v2.x API for saved_model export * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Clean up * Remove lambda in tf.function() * Revert "Remove lambda in tf.function()" to be compatible with TF v2.4 This reverts commit 46c7931f11dfdea6ae340c77287c35c30b9e0779. * Fix for pre-commit.ci * Cleanup1 * Cleanup2 * Backwards compatibility update * Update common.py * Update common.py * Cleanup3 Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Glenn Jocher <[email protected]>
1 parent 3a80ec4 commit 0a1a89d

File tree

2 files changed

+22
-8
lines changed

2 files changed

+22
-8
lines changed

export.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -247,11 +247,11 @@ def export_engine(model, im, file, train, half, simplify, workspace=4, verbose=F
247247

248248
def export_saved_model(model, im, file, dynamic,
249249
tf_nms=False, agnostic_nms=False, topk_per_class=100, topk_all=100, iou_thres=0.45,
250-
conf_thres=0.25, prefix=colorstr('TensorFlow SavedModel:')):
250+
conf_thres=0.25, keras=False, prefix=colorstr('TensorFlow SavedModel:')):
251251
# YOLOv5 TensorFlow SavedModel export
252252
try:
253253
import tensorflow as tf
254-
from tensorflow import keras
254+
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2
255255

256256
from models.tf import TFDetect, TFModel
257257

@@ -262,13 +262,26 @@ def export_saved_model(model, im, file, dynamic,
262262
tf_model = TFModel(cfg=model.yaml, model=model, nc=model.nc, imgsz=imgsz)
263263
im = tf.zeros((batch_size, *imgsz, 3)) # BHWC order for TensorFlow
264264
_ = tf_model.predict(im, tf_nms, agnostic_nms, topk_per_class, topk_all, iou_thres, conf_thres)
265-
inputs = keras.Input(shape=(*imgsz, 3), batch_size=None if dynamic else batch_size)
265+
inputs = tf.keras.Input(shape=(*imgsz, 3), batch_size=None if dynamic else batch_size)
266266
outputs = tf_model.predict(inputs, tf_nms, agnostic_nms, topk_per_class, topk_all, iou_thres, conf_thres)
267-
keras_model = keras.Model(inputs=inputs, outputs=outputs)
267+
keras_model = tf.keras.Model(inputs=inputs, outputs=outputs)
268268
keras_model.trainable = False
269269
keras_model.summary()
270-
keras_model.save(f, save_format='tf')
271-
270+
if keras:
271+
keras_model.save(f, save_format='tf')
272+
else:
273+
m = tf.function(lambda x: keras_model(x)) # full model
274+
spec = tf.TensorSpec(keras_model.inputs[0].shape, keras_model.inputs[0].dtype)
275+
m = m.get_concrete_function(spec)
276+
frozen_func = convert_variables_to_constants_v2(m)
277+
tfm = tf.Module()
278+
tfm.__call__ = tf.function(lambda x: frozen_func(x), [spec])
279+
tfm.__call__(im)
280+
tf.saved_model.save(
281+
tfm,
282+
f,
283+
options=tf.saved_model.SaveOptions(experimental_custom_gradients=False) if
284+
check_version(tf.__version__, '2.6') else tf.saved_model.SaveOptions())
272285
LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
273286
return keras_model, f
274287
except Exception as e:

models/common.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -359,7 +359,8 @@ def __init__(self, weights='yolov5s.pt', device=None, dnn=False, data=None):
359359
if saved_model: # SavedModel
360360
LOGGER.info(f'Loading {w} for TensorFlow SavedModel inference...')
361361
import tensorflow as tf
362-
model = tf.keras.models.load_model(w)
362+
keras = False # assume TF1 saved_model
363+
model = tf.keras.models.load_model(w) if keras else tf.saved_model.load(w)
363364
elif pb: # GraphDef https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt
364365
LOGGER.info(f'Loading {w} for TensorFlow GraphDef inference...')
365366
import tensorflow as tf
@@ -431,7 +432,7 @@ def forward(self, im, augment=False, visualize=False, val=False):
431432
else: # TensorFlow (SavedModel, GraphDef, Lite, Edge TPU)
432433
im = im.permute(0, 2, 3, 1).cpu().numpy() # torch BCHW to numpy BHWC shape(1,320,192,3)
433434
if self.saved_model: # SavedModel
434-
y = self.model(im, training=False).numpy()
435+
y = (self.model(im, training=False) if self.keras else self.model(im)[0]).numpy()
435436
elif self.pb: # GraphDef
436437
y = self.frozen_func(x=self.tf.constant(im)).numpy()
437438
elif self.tflite: # Lite

0 commit comments

Comments
 (0)