SwinTExCo / UI.py
duongttr's picture
Upload folder using huggingface_hub
62ef5f4
raw
history blame
2.63 kB
import streamlit as st
from PIL import Image
import torchvision.transforms as transforms
from streamlit_image_comparison import image_comparison
import numpy as np
import torch
import torchvision
######################################### Utils ########################################
# Extensions (lower-case, without the leading dot) used as the uploader's
# `type` filter and by check_type to route a file to the image or video branch.
video_extensions = ["mp4"]
image_extensions = ["png", "jpg", "jpeg"]


def check_type(file_name: str):
    """Classify a file as "image" or "video" based on its extension.

    The extension is the text after the last dot and is compared
    case-insensitively, so "photo.JPG" counts as an image, while a name that
    merely ends with the same letters without a dot (e.g. "notapng") does not.

    Args:
        file_name: File name to classify, e.g. the ``name`` of an uploaded file.

    Returns:
        "image" if the extension is in ``image_extensions``, "video" if it is
        in ``video_extensions``, otherwise None.
    """
    _, dot, extension = file_name.rpartition(".")
    if not dot:
        # No extension at all -> cannot classify.
        return None
    extension = extension.lower()
    if extension in image_extensions:
        return "image"
    if extension in video_extensions:
        return "video"
    return None
# Model preprocessing pipeline:
#   1. resize the PIL image to exactly 256x256,
#   2. convert it to a float CHW tensor (ToTensor scales 8-bit pixels to [0, 1]),
#   3. normalize each channel with the ImageNet mean/std used by torchvision's
#      pretrained models.
transform = transforms.Compose(
    [
        transforms.Resize(size=(256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.485, 0.456, 0.406),
            std=(0.229, 0.224, 0.225),
        ),
    ]
)
###################################### Load model ######################################
@st.cache_resource
def load_model():
    """Load a pretrained DeepLabV3-ResNet101 segmentation model in eval mode.

    ``st.cache_resource`` keeps one shared instance across Streamlit reruns and
    sessions, so the weights are loaded only once.

    Returns:
        The torchvision model, switched to inference mode via ``eval()``.
    """
    segmentation = torchvision.models.segmentation
    weights_enum = getattr(segmentation, "DeepLabV3_ResNet101_Weights", None)
    if weights_enum is not None:
        # torchvision >= 0.13 deprecates `pretrained=True` in favor of the
        # weights API; DEFAULT selects the same checkpoint `pretrained=True` did.
        model = segmentation.deeplabv3_resnet101(weights=weights_enum.DEFAULT)
    else:
        # Older torchvision releases without the weights API.
        model = segmentation.deeplabv3_resnet101(pretrained=True)
    model.eval()
    return model


model = load_model()
########################################## UI ##########################################
def _segmentation_to_rgb(logits, size):
    """Render per-pixel class scores as an RGB image for display.

    Args:
        logits: Class scores for a single image, shaped (num_classes, H, W),
            e.g. ``model(batch)["out"][0]``.
        size: Target ``(width, height)`` in PIL order. The mask is resized with
            nearest-neighbour sampling so class boundaries stay crisp.

    Returns:
        A (height, width, 3) uint8 array in which every pixel is painted with
        the color of its highest-scoring class.
    """
    class_map = logits.argmax(dim=0).cpu().numpy()
    # One deterministic, well-separated color per class index (palette recipe
    # from the PyTorch Hub DeepLabV3 example).
    palette = np.array([2**25 - 1, 2**15 - 1, 2**21 - 1], dtype=np.int64)
    colors = (np.arange(logits.shape[0], dtype=np.int64)[:, None] * palette % 255).astype(np.uint8)
    rgb = colors[class_map]
    return np.array(Image.fromarray(rgb).resize(size, Image.NEAREST))


st.title("Colorization")
uploaded_file = st.file_uploader("Upload grayscale image or video", type=image_extensions + video_extensions)
if uploaded_file:
    file_type = check_type(file_name=uploaded_file.name)
    # Image
    if file_type == "image":
        # The model needs 3 channels: grayscale ("L") and RGBA uploads would
        # otherwise break the channel permute / 3-value normalization.
        rgb_image = Image.open(uploaded_file).convert("RGB")
        # `transform` resizes, scales pixels to [0, 1] and normalizes; then add
        # the batch dimension the model expects.
        input_tensor = transform(rgb_image).unsqueeze(0)
        process_button = st.button("Process")
        if process_button:
            # Spinner text is Vietnamese, roughly "hang on...".
            with st.spinner("Từ từ coi..."):
                # Inference only: skip autograd graph construction.
                with torch.no_grad():
                    prediction = model(input_tensor)
                # Raw logits cannot be shown by st.image; convert them to an
                # RGB class map at the uploaded image's size so both panes align.
                segment = _segmentation_to_rgb(prediction["out"][0], rgb_image.size)
            image = np.array(rgb_image)
            st.image(segment)
            st.image(image)
            image_comparison(
                img1=image,
                img2=segment,
                label1="Grayscale",
                label2="Colorized",
                make_responsive=True,
                show_labels=True,
            )
    # Video
    elif file_type == "video":
        # Play back the uploaded clip itself (UploadedFile is file-like).
        st.video(uploaded_file)
    else:
        st.error(f"Unsupported file type: {uploaded_file.name}")
# CSS injected into the page to hide the element with id "MainMenu"
# (Streamlit's main menu) and the page footer. `unsafe_allow_html=True` is
# needed so st.markdown emits the raw <style> tag instead of escaping it.
hide_menu_style = """
<style>
#MainMenu {visibility: hidden; }
footer {visibility: hidden;}
</style>
"""
st.markdown(body=hide_menu_style, unsafe_allow_html=True)