Initial commit
scripts/export_onnx_model.py (new file, 204 lines)
@@ -0,0 +1,204 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
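
# Example invocation (the checkpoint filename below is illustrative):
#   python scripts/export_onnx_model.py \
#       --checkpoint sam_vit_b_01ec64.pth --model-type vit_b --output sam_vit_b.onnx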

import torch

from segment_anything import build_sam, build_sam_vit_b, build_sam_vit_l
from segment_anything.utils.onnx import SamOnnxModel

import argparse
import warnings

try:
    import onnxruntime  # type: ignore

    onnxruntime_exists = True
except ImportError:
    onnxruntime_exists = False

parser = argparse.ArgumentParser(
    description="Export the SAM prompt encoder and mask decoder to an ONNX model."
)

parser.add_argument(
    "--checkpoint", type=str, required=True, help="The path to the SAM model checkpoint."
)

parser.add_argument(
    "--output", type=str, required=True, help="The filename to save the ONNX model to."
)

parser.add_argument(
    "--model-type",
    type=str,
    default="default",
    help="In ['default', 'vit_b', 'vit_l']. Which type of SAM model to export.",
)

parser.add_argument(
    "--return-single-mask",
    action="store_true",
    help=(
        "If true, the exported ONNX model will only return the best mask, "
        "instead of returning multiple masks. For high resolution images "
        "this can improve runtime when upscaling masks is expensive."
    ),
)

parser.add_argument(
    "--opset",
    type=int,
    default=17,
    help="The ONNX opset version to use. Must be >=11",
)

parser.add_argument(
    "--quantize-out",
    type=str,
    default=None,
    help=(
        "If set, will quantize the model and save it with this name. "
        "Quantization is performed with quantize_dynamic from onnxruntime.quantization.quantize."
    ),
)

parser.add_argument(
    "--gelu-approximate",
    action="store_true",
    help=(
        "Replace GELU operations with approximations using tanh. Useful "
        "for some runtimes that have slow or unimplemented erf ops, used in GELU."
    ),
)

parser.add_argument(
    "--use-stability-score",
    action="store_true",
    help=(
        "Replaces the model's predicted mask quality score with the stability "
        "score calculated on the low resolution masks using an offset of 1.0."
    ),
)

parser.add_argument(
    "--return-extra-metrics",
    action="store_true",
    help=(
        "The model will return five results: (masks, scores, stability_scores, "
        "areas, low_res_logits) instead of the usual three. This can be "
        "significantly slower for high resolution outputs."
    ),
)


def run_export(
    model_type: str,
    checkpoint: str,
    output: str,
    opset: int,
    return_single_mask: bool,
    gelu_approximate: bool = False,
    use_stability_score: bool = False,
    return_extra_metrics=False,
):
    print("Loading model...")
    if model_type == "vit_b":
        sam = build_sam_vit_b(checkpoint)
    elif model_type == "vit_l":
        sam = build_sam_vit_l(checkpoint)
    else:
        sam = build_sam(checkpoint)

    onnx_model = SamOnnxModel(
        model=sam,
        return_single_mask=return_single_mask,
        use_stability_score=use_stability_score,
        return_extra_metrics=return_extra_metrics,
    )

    if gelu_approximate:
        for n, m in onnx_model.named_modules():
            if isinstance(m, torch.nn.GELU):
                m.approximate = "tanh"
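
    # Only the number of prompt points is dynamic in the exported graph; the other
    # inputs keep the fixed shapes of the tracing tensors constructed below.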
    dynamic_axes = {
        "point_coords": {1: "num_points"},
        "point_labels": {1: "num_points"},
    }

    embed_dim = sam.prompt_encoder.embed_dim
    embed_size = sam.prompt_encoder.image_embedding_size
    mask_input_size = [4 * x for x in embed_size]
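    # The dummy tensors below are used only for tracing. With SAM's default 1024x1024
    # input resolution, embed_size is 64x64 and mask_input_size is 256x256; point
    # labels 0/1 mark negative/positive clicks and 2/3 mark box corners.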
    dummy_inputs = {
        "image_embeddings": torch.randn(1, embed_dim, *embed_size, dtype=torch.float),
        "point_coords": torch.randint(low=0, high=1024, size=(1, 5, 2), dtype=torch.float),
        "point_labels": torch.randint(low=0, high=4, size=(1, 5), dtype=torch.float),
        "mask_input": torch.randn(1, 1, *mask_input_size, dtype=torch.float),
        "has_mask_input": torch.tensor([1], dtype=torch.float),
        "orig_im_size": torch.tensor([1500, 2250], dtype=torch.float),
    }
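
    # Run the wrapper once eagerly so shape or dtype mismatches surface before
    # torch.onnx.export traces the same inputs.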
    _ = onnx_model(**dummy_inputs)

    output_names = ["masks", "iou_predictions", "low_res_masks"]

    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
        warnings.filterwarnings("ignore", category=UserWarning)
        with open(output, "wb") as f:
            print(f"Exporting onnx model to {output}...")
            torch.onnx.export(
                onnx_model,
                tuple(dummy_inputs.values()),
                f,
                export_params=True,
                verbose=False,
                opset_version=opset,
                do_constant_folding=True,
                input_names=list(dummy_inputs.keys()),
                output_names=output_names,
                dynamic_axes=dynamic_axes,
            )

    if onnxruntime_exists:
        ort_inputs = {k: to_numpy(v) for k, v in dummy_inputs.items()}
        ort_session = onnxruntime.InferenceSession(output)
        _ = ort_session.run(None, ort_inputs)
        print("Model has successfully been run with ONNXRuntime.")


def to_numpy(tensor):
    return tensor.cpu().numpy()


if __name__ == "__main__":
    args = parser.parse_args()
    run_export(
        model_type=args.model_type,
        checkpoint=args.checkpoint,
        output=args.output,
        opset=args.opset,
        return_single_mask=args.return_single_mask,
        gelu_approximate=args.gelu_approximate,
        use_stability_score=args.use_stability_score,
        return_extra_metrics=args.return_extra_metrics,
    )

    if args.quantize_out is not None:
        assert onnxruntime_exists, "onnxruntime is required to quantize the model."
        from onnxruntime.quantization import QuantType  # type: ignore
        from onnxruntime.quantization.quantize import quantize_dynamic  # type: ignore
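
        # Dynamic quantization stores weights as 8-bit integers and quantizes
        # activations on the fly at inference, typically cutting the exported
        # file to roughly a quarter of its float32 size.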
        print(f"Quantizing model and writing to {args.quantize_out}...")
        quantize_dynamic(
            model_input=args.output,
            model_output=args.quantize_out,
            optimize_model=True,
            per_channel=False,
            reduce_range=False,
            weight_type=QuantType.QUInt8,
        )
        print("Done!")