Spaces:
Sleeping
Sleeping
import streamlit as st | |
from PIL import Image | |
import torchvision.transforms as transforms | |
from streamlit_image_comparison import image_comparison | |
import numpy as np | |
import torch | |
import torchvision | |
######################################### Utils ########################################
# Supported upload extensions, stored without the leading dot. The same lists are
# handed to the file-uploader widget further down, so they stay plain lists.
video_extensions = ["mp4"]
image_extensions = ["png", "jpg"]


def check_type(file_name: str):
    """Classify a file as an image or a video from its extension.

    The extension is the text after the last ``.`` in *file_name* and is
    compared case-insensitively, so ``photo.PNG`` is treated like
    ``photo.png``. A name without any dot has no extension and is never
    matched (a bare suffix test would wrongly accept e.g. ``filepng``).

    Args:
        file_name: Name of the file, e.g. the ``name`` of an uploaded file.

    Returns:
        ``"image"`` if the extension is in ``image_extensions``,
        ``"video"`` if it is in ``video_extensions``, otherwise ``None``.
    """
    _, dot, extension = file_name.rpartition(".")
    if not dot:
        # No "." at all: nothing that can be called an extension.
        return None

    extension = extension.lower()
    if extension in image_extensions:
        return "image"
    if extension in video_extensions:
        return "video"
    return None
# Preprocessing pipeline for PIL images: resize to 256x256, convert to a tensor,
# then normalize each channel with the given per-channel mean and std.
# NOTE(review): nothing in the visible part of this file calls `transform`;
# confirm whether it is still needed.
transform = transforms.Compose(
    [
        transforms.Resize(size=(256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.485, 0.456, 0.406),
            std=(0.229, 0.224, 0.225),
        ),
    ]
)
###################################### Load model ######################################
def _no_cache(func):
    """Fallback decorator used when Streamlit offers no resource cache."""
    return func


# Streamlit re-executes this script from the top on every widget interaction, so
# an uncached loader would rebuild the whole network on each rerun. Use whichever
# resource cache the installed Streamlit version exposes.
_cache_model = (
    getattr(st, "cache_resource", None)
    or getattr(st, "experimental_singleton", None)
    or _no_cache
)


@_cache_model
def load_model():
    """Build the pretrained DeepLabV3-ResNet101 segmentation model.

    The model is switched to evaluation mode before being returned. When a
    Streamlit resource cache is available, the same instance is reused
    across script reruns instead of being constructed again.

    Returns:
        The torchvision segmentation model, in eval mode.
    """
    # NOTE(review): newer torchvision releases prefer the `weights=` argument
    # over `pretrained=True`; kept as-is so the pinned version keeps working.
    model = torchvision.models.segmentation.deeplabv3_resnet101(pretrained=True)
    model.eval()
    return model


model = load_model()
########################################## UI ##########################################
st.title("Colorization")
uploaded_file = st.file_uploader("Upload grayscale image or video", type=image_extensions + video_extensions)
if uploaded_file:
    # Image
    if check_type(file_name=uploaded_file.name) == "image":
        # Force exactly three 8-bit channels: grayscale uploads decode to a 2-D
        # array and RGBA ones to four channels, and either would break the
        # channel permute / 3-value normalization below.
        image = np.array(Image.open(uploaded_file).convert("RGB"))

        # The mean/std below are expressed for pixel values in [0, 1], so the
        # 0-255 bytes are rescaled before normalizing.
        input_tensor = torchvision.transforms.functional.normalize(
            torch.from_numpy(image).permute(2, 0, 1).float().div(255.0),
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ).unsqueeze(0)

        process_button = st.button("Process")
        if process_button:
            with st.spinner("Từ từ coi..."):
                # Inference only: skip building the autograd graph.
                with torch.no_grad():
                    prediction = model(input_tensor)

                # Per the torchvision docs, "out" holds one score map per class
                # rather than an RGB image, so it cannot be displayed directly.
                # Take the best class for every pixel and paint it with a
                # fixed per-class color.
                class_scores = prediction["out"][0]
                class_map = class_scores.argmax(dim=0).numpy()
                class_ids = np.arange(class_scores.shape[0])[:, None]
                palette = (class_ids * np.array([2**25 - 1, 2**15 - 1, 2**21 - 1]) % 255).astype(np.uint8)
                segment = palette[class_map]

            st.image(segment)
            st.image(image)
            image_comparison(
                img1=image,
                img2=segment,
                label1="Grayscale",
                label2="Colorized",
                make_responsive=True,
                show_labels=True,
            )
    # Video
    else:
        # TODO: play/process the uploaded video itself; a fixed placeholder clip
        # is shown for now and the upload is ignored.
        st.video("https://youtu.be/dQw4w9WgXcQ")
# CSS that sets the `#MainMenu` element and the page `footer` to hidden.
# It is injected as raw HTML, which is why `unsafe_allow_html` must be enabled.
hide_menu_style = (
    "\n"
    "<style>\n"
    "#MainMenu {visibility: hidden; }\n"
    "footer {visibility: hidden;}\n"
    "</style>\n"
)
st.markdown(hide_menu_style, unsafe_allow_html=True)