This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit 1995747

bfineran, anmarques, mgoin authored
Convert input for quantized YOLOv8 (#1521) (#1628)
* Added call to function that skips quantization of the input if the model is quantized
* Removed bare exception
* Moved input data type conversion upstream so it is properly saved and used as input to ORT
* Style and quality fixes

Co-authored-by: Alexandre Marques <[email protected]>
Co-authored-by: Michael Goin <[email protected]>
1 parent 9f7606c commit 1995747

File tree

2 files changed: +16 -8 lines

  src/sparseml/yolov8/trainers.py
  src/sparseml/yolov8/utils/export_samples.py


src/sparseml/yolov8/trainers.py

Lines changed: 8 additions & 1 deletion
@@ -29,6 +29,7 @@

 from sparseml.optim.helpers import load_recipe_yaml_str
 from sparseml.pytorch.optim.manager import ScheduledModifierManager
+from sparseml.pytorch.sparsification.quantization import skip_onnx_input_quantize
 from sparseml.pytorch.utils import ModuleExporter
 from sparseml.pytorch.utils.helpers import download_framework_model_by_recipe_type
 from sparseml.pytorch.utils.logger import LoggerManager, PythonLogger, WANDBLogger

@@ -729,7 +730,13 @@ def export(self, **kwargs):
             else ["output0"],
         )

-        onnx.checker.check_model(os.path.join(save_dir, name))
+        complete_path = os.path.join(save_dir, name)
+        try:
+            skip_onnx_input_quantize(complete_path, complete_path)
+        except Exception:
+            pass
+
+        onnx.checker.check_model(complete_path)
         deployment_folder = exporter.create_deployment_folder(onnx_model_name=name)
         if args["export_samples"]:
             trainer_config = get_cfg(cfg=DEFAULT_SPARSEML_CONFIG_PATH)
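For context, skip_onnx_input_quantize rewrites a quantized ONNX graph so it accepts uint8 inputs directly rather than float32; on a model with no input quantization there is nothing to remove and the call fails, which appears to be why it is wrapped in try/except here. A minimal standalone sketch of the same pattern, assuming an existing export at the hypothetical path "model.onnx":

import onnx

from sparseml.pytorch.sparsification.quantization import skip_onnx_input_quantize

model_path = "model.onnx"  # hypothetical path to an exported YOLOv8 model

try:
    # Rewrite the graph so a quantized model takes uint8 input directly;
    # source and destination may be the same file, as in the diff above.
    skip_onnx_input_quantize(model_path, model_path)
except Exception:
    # A non-quantized model has no input quantization to skip; leave it
    # unchanged and continue, mirroring the trainer's behavior.
    pass

# Validate the (possibly rewritten) model, as the trainer does next.
onnx.checker.check_model(model_path)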

src/sparseml/yolov8/utils/export_samples.py

Lines changed: 8 additions & 7 deletions
@@ -106,10 +106,15 @@ def export_sample_inputs_outputs(
         preprocessed_batch = preprocess(batch=batch, device=device)
         image = preprocessed_batch["img"]

-        # Save inputs as numpy array
-        _export_inputs(image, sample_in_dir, file_idx, save_inputs_as_uint8)
         # Save torch outputs as numpy array
         _export_torch_outputs(image, model, sample_out_dir_torch, file_idx)
+
+        # Convert input data type if needed
+        if save_inputs_as_uint8:
+            image = (255 * image).to(dtype=torch.uint8)
+
+        # Save inputs as numpy array
+        _export_inputs(image, sample_in_dir, file_idx)
         # Save onnxruntime outputs as numpy array
         _export_ort_outputs(
             image.cpu().numpy(), ort_session, sample_out_dir_ort, file_idx

@@ -166,13 +171,9 @@ def _export_ort_outputs(
     numpy.savez(sample_output_filename, preds)


-def _export_inputs(
-    image: torch.Tensor, sample_in_dir: str, file_idx: str, save_inputs_as_uint8: bool
-):
+def _export_inputs(image: torch.Tensor, sample_in_dir: str, file_idx: str):

     sample_in = image.detach().to("cpu")
-    if save_inputs_as_uint8:
-        sample_in = (255 * sample_in).to(dtype=torch.uint8)

     sample_input_filename = os.path.join(sample_in_dir, f"inp-{file_idx}.npz")
     numpy.savez(sample_input_filename, sample_in)
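The fix here is ordering: the uint8 conversion previously happened inside _export_inputs, after the float tensor had already been handed to onnxruntime, so the saved sample inputs did not match what ORT consumed. Moving the conversion upstream makes the same tensor flow to both. A minimal sketch of the relocated conversion, using a random stand-in for the preprocessed batch (directory name, file index, and tensor shape are hypothetical):

import os

import numpy
import torch

sample_in_dir = "sample_inputs"  # hypothetical output directory
file_idx = "0000"
save_inputs_as_uint8 = True

os.makedirs(sample_in_dir, exist_ok=True)

# Stand-in for preprocessed_batch["img"]: YOLOv8 preprocessing yields
# float images scaled to [0, 1], hence the multiply by 255 below.
image = torch.rand(1, 3, 640, 640)

# Convert once, upstream, so the same uint8 tensor is both saved to disk
# and fed to onnxruntime downstream.
if save_inputs_as_uint8:
    image = (255 * image).to(dtype=torch.uint8)

sample_in = image.detach().to("cpu")
numpy.savez(os.path.join(sample_in_dir, f"inp-{file_idx}.npz"), sample_in)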
