forked from microsoft/onnxruntime-extensions
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathsuperresolution_e2e.py
161 lines (124 loc) · 7.03 KB
/
superresolution_e2e.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
import io
import numpy as np
import onnxruntime as ort
import os
from pathlib import Path
from PIL import Image
_this_dirpath = Path(os.path.dirname(os.path.abspath(__file__)))
ONNX_MODEL = 'pytorch_superresolution.onnx'
ONNX_MODEL_WITH_PRE_POST_PROCESSING = 'pytorch_superresolution.with_pre_post_processing.onnx'
# Export pytorch superresolution model as per
# https://pytorch.org/tutorials/advanced/super_resolution_with_onnxruntime.html
def convert_pytorch_superresolution_to_onnx():
import torch.utils.model_zoo as model_zoo
import torch.onnx
# Super Resolution model definition in PyTorch
import torch.nn as nn
import torch.nn.init as init
class SuperResolutionNet(nn.Module):
def __init__(self, upscale_factor, inplace=False):
super(SuperResolutionNet, self).__init__()
self.relu = nn.ReLU(inplace=inplace)
self.conv1 = nn.Conv2d(1, 64, (5, 5), (1, 1), (2, 2))
self.conv2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1))
self.conv3 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1))
self.conv4 = nn.Conv2d(32, upscale_factor ** 2, (3, 3), (1, 1), (1, 1))
self.pixel_shuffle = nn.PixelShuffle(upscale_factor)
self._initialize_weights()
def forward(self, x):
x = self.relu(self.conv1(x))
x = self.relu(self.conv2(x))
x = self.relu(self.conv3(x))
x = self.pixel_shuffle(self.conv4(x))
return x
def _initialize_weights(self):
init.orthogonal_(self.conv1.weight, init.calculate_gain('relu'))
init.orthogonal_(self.conv2.weight, init.calculate_gain('relu'))
init.orthogonal_(self.conv3.weight, init.calculate_gain('relu'))
init.orthogonal_(self.conv4.weight)
# Create the super-resolution model by using the above model definition.
torch_model = SuperResolutionNet(upscale_factor=3)
# Load pretrained model weights
model_url = 'https://s3.amazonaws.com/pytorch/test_data/export/superres_epoch100-44c6958e.pth'
batch_size = 1 # fix batch size to 1 for use in mobile scenarios
# Initialize model with the pretrained weights
map_location = lambda storage, loc: storage
torch_model.load_state_dict(model_zoo.load_url(model_url, map_location=map_location))
# set the model to inference mode
torch_model.eval()
# Create random input to the model and run it
x = torch.randn(batch_size, 1, 224, 224, requires_grad=True)
# Export the model
torch.onnx.export(torch_model, # model being run
x, # model input (or a tuple for multiple inputs)
ONNX_MODEL, # where to save the model (can be a file or file-like object)
export_params=True, # store the trained parameter weights inside the model file
opset_version=15, # the ONNX version to export the model to
do_constant_folding=True, # whether to execute constant folding for optimization
input_names=['input'], # the model's input names
output_names=['output']) # the model's output names
def add_pre_post_processing(output_format: str = "png"):
# Use the pre-defined helper to add pre/post processing to a super resolution model based on YCbCr input.
# Note: if you're creating a custom pre/post processing pipeline, use
# `from onnxruntime_extensions.tools.pre_post_processing import *` to pull in the pre/post processing infrastructure
# and Step definitions.
from onnxruntime_extensions.tools import add_pre_post_processing_to_model as add_ppp
# ORT 1.14 and later support ONNX opset 18, which added antialiasing to the Resize operator.
# Results are much better when that can be used. Minimum opset is 16.
onnx_opset = 16
from packaging import version
if version.parse(ort.__version__) >= version.parse("1.14.0"):
onnx_opset = 18
# add the processing to the model and output a PNG format image. JPG is also valid.
add_ppp.superresolution(Path(ONNX_MODEL), Path(ONNX_MODEL_WITH_PRE_POST_PROCESSING), output_format, onnx_opset)
def _center_crop_to_square(img: Image):
if img.height != img.width:
target_size = img.width if img.width < img.height else img.height
w_start = int(np.floor((img.width - target_size) / 2))
w_end = w_start + target_size
h_start = int(np.floor((img.height - target_size) / 2))
h_end = h_start + target_size
return img.crop((w_start, h_start, w_end, h_end))
else:
return img
def run_updated_onnx_model():
from onnxruntime_extensions import get_library_path
so = ort.SessionOptions()
# register the custom operators for the image decode/encode pre/post processing provided by onnxruntime-extensions
# with onnxruntime. if we do not do this we'll get an error on model load about the operators not being found.
ortext_lib_path = get_library_path()
so.register_custom_ops_library(ortext_lib_path)
inference_session = ort.InferenceSession(ONNX_MODEL_WITH_PRE_POST_PROCESSING, so)
test_image_path = _this_dirpath / 'data' / 'super_res_input.png'
test_image_bytes = np.fromfile(test_image_path, dtype=np.uint8)
outputs = inference_session.run(['image_out'], {'image': test_image_bytes})
upsized_image_bytes = outputs[0]
original = Image.open(io.BytesIO(test_image_bytes))
updated = Image.open(io.BytesIO(upsized_image_bytes))
# centered crop of original to match the area processed
original_cropped = _center_crop_to_square(original)
return original_cropped, updated
if __name__ == '__main__':
# check onnxruntime-extensions version
import onnxruntime_extensions
from packaging import version
if version.parse(onnxruntime_extensions.__version__) < version.parse("0.6.0"):
# temporarily install using nightly until we have official release on pypi.
raise ImportError(
f"onnxruntime_extensions version 0.6 or later is required. {onnxruntime_extensions.__version__} is installed. "
"Please install the latest version using "
"`pip install --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/ORT-Nightly/pypi/simple/ onnxruntime-extensions`") # noqa
convert_pytorch_superresolution_to_onnx()
add_pre_post_processing('png')
original_img, updated_img = run_updated_onnx_model()
new_width, new_height = updated_img.size
# create a side-by-side image with both.
# do a plain resize of original to model input size followed by model output size
# so side-by-side is an easier comparison
resized_orig_img = original_img.resize((224, 224))
resized_orig_img = resized_orig_img.resize((new_width, new_height))
combined = Image.new('RGB', (new_width * 2, new_height))
combined.paste(resized_orig_img, (0, 0))
combined.paste(updated_img, (new_width, 0))
# NOTE: The output is significantly better with ONNX opset 18 as Resize supports anti-aliasing.
combined.show('Original resized vs original vs Super Resolution resized')