Skip to content

Commit

Permalink
Add base of demo app
Browse files Browse the repository at this point in the history
  • Loading branch information
dblasko committed Nov 16, 2023
1 parent 1d7860f commit 05255ae
Show file tree
Hide file tree
Showing 7 changed files with 138 additions and 1 deletion.
6 changes: 6 additions & 0 deletions .streamlit/config.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
[theme]
primaryColor="#a06c4b"
backgroundColor="#110305"
secondaryBackgroundColor="#121010"
textColor="#a06c4b"
font="sans serif"
7 changes: 6 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -177,4 +177,9 @@ To add further tests, simply add a new file in the `tests` folder, and name it `
*Coming soon.*

## Running the web application
*Coming soon.*
To start the inference web application, run the following command from the root directory of the project:
```bash
$ streamlit run app/app.py
```
The web application should then be accessible at `localhost:8501` in your browser and allow you to upload images in any size and format, enhance them and download the enhanced version:
![Screenshot from the demonstration application](utils/documentation_images/demo_app.png)
125 changes: 125 additions & 0 deletions app/app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
import streamlit as st
from PIL import Image
import io
import sys, os
import torch
import torchvision.transforms as T
import torchvision.utils as vutils
import base64

sys.path.append(".")
from model.MIRNet.model import MIRNet


def run_model(input_image):
device = (
torch.device("cuda")
if torch.cuda.is_available()
else torch.device("mps")
if torch.backends.mps.is_available()
else torch.device("cpu")
)
print(f"-> {device.type} device detected.")

model = MIRNet(num_features=64).to(device)
checkpoint = torch.load(
"model/weights/Mirnet_enhance_finetune-35-early-stopped_64x64.pth",
map_location=device,
)
model.load_state_dict(checkpoint["model_state_dict"])

model.eval()
with torch.no_grad():
img = input_image
img_tensor = T.Compose(
[
T.Resize(400),
T.ToTensor(),
T.Normalize([0.0, 0.0, 0.0], [1.0, 1.0, 1.0]),
]
)(img).unsqueeze(0)
img_tensor = img_tensor.to(device)

if img_tensor.shape[2] % 8 != 0:
img_tensor = img_tensor[:, :, : -(img_tensor.shape[2] % 8), :]
if img_tensor.shape[3] % 8 != 0:
img_tensor = img_tensor[:, :, :, : -(img_tensor.shape[3] % 8)]

output = model(img_tensor)

vutils.save_image(output, open(f"temp.png", "wb"))
output_image = Image.open("temp.png")
os.remove("temp.png")
return output_image


def get_base64_font(font_path):
with open(font_path, "rb") as font_file:
return base64.b64encode(font_file.read()).decode()


st.set_page_config(layout="wide")

font_name = "Gloock"
gloock_b64 = get_base64_font("utils/assets/Gloock-Regular.ttf")
font_name_text = "Merriweather sans"
merri_b64 = get_base64_font("utils/assets/MerriweatherSans-Regular.ttf")
hide_streamlit_style = f"""
<style>
#MainMenu {'{visibility: hidden;}'}
footer {'{visibility: hidden;}'}
@font-face {{
font-family: '{font_name}';
src: url(data:font/ttf;base64,{gloock_b64}) format('truetype');
}}
@font-face {{
font-family: '{font_name_text}';
src: url(data:font/ttf;base64,{merri_b64}) format('truetype');
}}
span {{
font-family: '{font_name_text}';
}}
.e1nzilvr1, .st-emotion-cache-10trblm {{
font-family: '{font_name}';
font-size: 65px;
}}
</style>
"""
st.markdown(hide_streamlit_style, unsafe_allow_html=True)

st.title("Low-light event-image enhancement with MIRNet.")

# File uploader widget
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
if uploaded_file is not None:
# To read file as bytes:
bytes_data = uploaded_file.getvalue()
image = Image.open(io.BytesIO(bytes_data)).convert("RGB")

# Create two columns for images
col1, col2 = st.columns(2)

with col1:
st.image(image, caption="Original Image", use_column_width="always")

# Button to enhance image
if st.button("Enhance Image"):
with col2:
# Assume your model has a function 'enhance' to enhance the image
enhanced_image = run_model(image)
st.image(
enhanced_image, caption="Enhanced Image", use_column_width="always"
)

# Download button
buf = io.BytesIO()
enhanced_image.save(buf, format="JPEG")
byte_im = buf.getvalue()
st.download_button(
label="Download image",
data=byte_im,
file_name="enhanced_image.jpg",
mime="image/jpeg",
)
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ wandb
pyyaml
tqdm
pytest
streamlit
Binary file added utils/assets/Gloock-Regular.ttf
Binary file not shown.
Binary file added utils/assets/MerriweatherSans-Regular.ttf
Binary file not shown.
Binary file added utils/documentation_images/demo_app.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit 05255ae

Please sign in to comment.