Skip to content

Commit

Permalink
Make single-image inference script more robust
Browse files Browse the repository at this point in the history
  • Loading branch information
dblasko committed Dec 20, 2023
1 parent fcf5885 commit 1c34cb1
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions inference/enhance_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from PIL import Image
import sys
import argparse
import torchvision.utils as vutils

sys.path.append(".")

Expand Down Expand Up @@ -40,12 +41,16 @@ def run_inference(input_image_path, output_folder_path, device, model_path=MODEL
)(img).unsqueeze(0)
img_tensor = img_tensor.to(device)

if img_tensor.shape[2] % 8 != 0:
img_tensor = img_tensor[:, :, : -(img_tensor.shape[2] % 8), :]
if img_tensor.shape[3] % 8 != 0:
img_tensor = img_tensor[:, :, :, : -(img_tensor.shape[3] % 8)]

output = model(img_tensor)
except:
print("Could not open image - verify the provided path.")
return
try:
output_image = T.ToPILImage()(output.squeeze().cpu())
out_path = (
output_folder_path
if output_folder_path[-1] == "/"
Expand All @@ -54,7 +59,7 @@ def run_inference(input_image_path, output_folder_path, device, model_path=MODEL
+ input_image_path.split("/")[-1].split(".")[0]
+ "_enhanced.png"
)
output_image.save(out_path)
vutils.save_image(output, open(out_path, "wb"))
print('-> Enhanced image saved to "' + out_path + '".')
except:
print("Error: Could not save image - verify the provided path.")
Expand Down

0 comments on commit 1c34cb1

Please sign in to comment.