-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtorch2onnx.py
88 lines (74 loc) · 4.03 KB
/
torch2onnx.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import torch
import argparse
import torch.nn as nn
import onnx
from model import CustomModel
from loguru import logger
def main(args):
    """Load a ``CustomModel`` checkpoint, export it to ONNX and validate the file.

    Args:
        args: Parsed CLI namespace (see ``arguement_parser``) providing
            ``model_path``, ``img_size``, ``opset``, ``onnx_name``,
            ``model_name``, ``target_size``, ``device`` and ``batch_size``.

    A ``batch_size`` of 0 exports the graph with a dynamic batch axis on both
    the ``input`` and ``output`` tensors; any other value bakes that batch
    size into the graph. Errors raised while exporting or checking the ONNX
    file are logged with a traceback and not re-raised; errors while loading
    the checkpoint propagate to the caller.
    """
    model_path = args.model_path
    img_size = args.img_size
    opset = args.opset
    onnx_name = args.onnx_name
    model_name = args.model_name
    target_size = args.target_size
    batch_size = args.batch_size

    # Fall back to CPU when CUDA was requested but is not available.
    device = "cuda" if "cuda" in args.device and torch.cuda.is_available() else "cpu"

    # Load model
    logger.info("Loading model...")
    model = CustomModel(model_name=model_name, target_size=target_size, pretrained=False)
    checkpoint = torch.load(model_path, map_location=torch.device(device))
    # Accept both training checkpoints ({"state_dict": ...}) and bare state dicts.
    if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
        checkpoint = checkpoint["state_dict"]
    model.load_state_dict(checkpoint)
    # Place the model on the selected device so --device controls where tracing runs.
    model.to(device)
    model.eval()

    # Decide dynamic vs. fixed batch BEFORE substituting the dummy batch size.
    # Overwriting batch_size first (as before) made the dynamic branch unreachable.
    dynamic_batch = batch_size == 0
    dummy_batch = 1 if dynamic_batch else batch_size
    # The dummy input only drives tracing, so it does not need gradients.
    x = torch.randn(dummy_batch, 3, img_size, img_size, device=device)
    dynamic_axes = (
        {"input": {0: "batch_size"}, "output": {0: "batch_size"}}
        if dynamic_batch
        else None
    )

    try:
        # Export the model
        logger.info("Exporting model...")
        torch.onnx.export(
            model,                      # model being run
            x,                          # example input used for tracing
            onnx_name,                  # output path (or file-like object)
            export_params=True,         # store trained weights inside the file
            opset_version=opset,        # ONNX opset to target
            do_constant_folding=True,   # fold constant sub-graphs at export time
            input_names=["input"],
            output_names=["output"],
            dynamic_axes=dynamic_axes,  # None -> all axes fixed
        )
        logger.info("Onnx model export successful!")
        # Test exported model
        logger.info("Checking generated onnx model")
        onnx_model = onnx.load(onnx_name)
        onnx.checker.check_model(onnx_model)
        logger.info("Onnx model checked!")
    except Exception as e:
        # Best-effort boundary: keep the full traceback instead of only str(e).
        logger.exception(f"Exception occurred : {e}")
def arguement_parser(argv=None):
    """Build the command-line parser for the ONNX export script and parse it.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) parses
            ``sys.argv[1:]``, exactly as before; passing a list lets tests or
            other scripts drive the parser directly.

    Returns:
        argparse.Namespace holding the export options.

    Exits (via ``parser.error``) when ``--batch_size`` is negative or
    ``--img_size`` is not positive, because those values would otherwise
    fail later inside ``torch.randn``.
    """
    parser = argparse.ArgumentParser(description="Export a PyTorch classifier checkpoint to ONNX")
    parser.add_argument('--model_path', type=str, default="/home/sahil/Documents/Classifiers/weight_files/classifier_statedict_ep4_0.937.pt", help="PyTorch model checkpoint path")
    parser.add_argument('--target_size', type=int, default=6, help='Number of classes')
    parser.add_argument('--model_name', type=str, default="resnet50", help="Model name from Timm")
    parser.add_argument('--img_size', type=int, default=224, help='Input image size')
    parser.add_argument('--opset', type=int, default=11, help='Opset value for exporting the model')
    parser.add_argument('--onnx_name', type=str, default="classifier.onnx", help='Output model name(Onnx)')
    parser.add_argument('--device', type=str, default="cuda", help='Device')
    parser.add_argument('--batch_size', type=int, default=0, help='Export batch size; 0 (default) exports a dynamic batch axis')
    args = parser.parse_args(argv)
    if args.batch_size < 0:
        parser.error("--batch_size must be >= 0 (0 means dynamic)")
    if args.img_size <= 0:
        parser.error("--img_size must be a positive integer")
    return args
if __name__ == "__main__":
    # Script entry point: parse the CLI options, then run the export.
    cli_args = arguement_parser()
    main(cli_args)