Skip to content

Commit

Permalink
Add image enhancement script
Browse files Browse the repository at this point in the history
  • Loading branch information
dblasko committed Nov 12, 2023
1 parent 90a6e1f commit 55f43c6
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 1 deletion.
2 changes: 1 addition & 1 deletion dataset_generation/pretraining_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
os.makedirs(f"data/pretraining/{split}/imgs", exist_ok=True)
os.makedirs(f"data/pretraining/{split}/targets", exist_ok=True)

dataset = load_dataset("huggan/night2day")
dataset = load_dataset("geekyrakshit/LoL-Dataset")

train_size = int(0.85 * len(dataset["train"]))
val_size = int(0.10 * len(dataset["train"]))
Expand Down
96 changes: 96 additions & 0 deletions inference/enhance_image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import argparse

import torch
import torchvision.transforms as T
import torchvision.utils as vutils
from torchvision.utils import make_grid
from PIL import Image, ImageDraw, ImageFont
import os
import sys
import argparse
from torch.utils.data import DataLoader

sys.path.append(".")

from dataset_generation.PretrainingDataset import PretrainingDataset
from training.training_utils.CharbonnierLoss import CharbonnierLoss
from model.MIRNet.model import MIRNet
from training.train import validate

"""
Run this script to run model inference on a specified image and write the enhanced image to an output folder.
Usage: python inference/enhance_image.py -i <path_to_input_image> [-o <path_to_output_folder>]
or python inference/enhance_image.py --input_image_path <path_to_input_image> [--output_folder_path <path_to_output_folder>]
If the output folder is not specified, the enhanced image is written to the directory the script is run from.
"""

IMG_SIZE = 400
NUM_FEATURES = 64
MODEL_PATH = "model/weights/Mirnet_enhance_finetune-35-early-stopped_64x64.pth" # f"model/weights/Mirnet_enhance{99}_64x64.pth"


def run_inference(input_image_path, output_folder_path, device, model_path=MODEL_PATH):
model = MIRNet(num_features=NUM_FEATURES).to(device)
checkpoint = torch.load(model_path, map_location=device)
model.load_state_dict(checkpoint["model_state_dict"])

model.eval()
with torch.no_grad():
try:
img = Image.open(input_image_path)
img_tensor = T.Compose(
[
T.Resize(IMG_SIZE),
T.ToTensor(),
T.Normalize([0.0, 0.0, 0.0], [1.0, 1.0, 1.0]),
]
)(img).unsqueeze(0)
img_tensor = img_tensor.to(device)

output = model(img_tensor)
except:
print("Could not open image - verify the provided path.")
return
try:
output_image = T.ToPILImage()(output.squeeze().cpu())
out_path = (
output_folder_path
if output_folder_path[-1] == "/"
else output_folder_path
+ "/"
+ input_image_path.split("/")[-1].split(".")[0]
+ "_enhanced.png"
)
output_image.save(out_path)
print('-> Enhanced image saved to "' + out_path + '".')
except:
print("Error: Could not save image - verify the provided path.")
return


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--input_image_path",
"-i",
help="Path to the input image to enhance.",
required=True,
)
parser.add_argument(
"--output_folder_path",
"-o",
help="Path to the output folder to save the enhanced image to.",
default=".",
)
args = parser.parse_args()

device = (
torch.device("cuda")
if torch.cuda.is_available()
else torch.device("mps")
if torch.backends.mps.is_available()
else torch.device("cpu")
)
print(f"-> {device.type} device detected.")

run_inference(args.input_image_path, args.output_folder_path, device)

0 comments on commit 55f43c6

Please sign in to comment.