File size: 2,627 Bytes
62ef5f4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import streamlit as st
from PIL import Image
import torchvision.transforms as transforms
from streamlit_image_comparison import image_comparison
import numpy as np
import torch
import torchvision

######################################### Utils ########################################
video_extensions = ["mp4"]
image_extensions = ["png", "jpg"]


def check_type(file_name: str):
    """Classify a file as an image or a video based on its extension.

    Matching is case-insensitive and requires a real ``.ext`` suffix, so
    ``"IMG_001.JPG"`` is recognised as an image while ``"archivepng"``
    (no dot) is not.

    Args:
        file_name: Name of the file, e.g. ``uploaded_file.name``.

    Returns:
        ``"image"`` if the extension is listed in ``image_extensions``,
        ``"video"`` if it is listed in ``video_extensions``, otherwise ``None``.
    """
    # Lower-case once so upper-case extensions are not misclassified.
    name = file_name.lower()
    # Prefix each extension with "." so only a genuine suffix matches;
    # str.endswith accepts a tuple of candidates.
    if name.endswith(tuple(f".{ext.lower()}" for ext in image_extensions)):
        return "image"
    if name.endswith(tuple(f".{ext.lower()}" for ext in video_extensions)):
        return "video"
    return None


# Per-channel (R, G, B) mean and standard deviation used to normalize inputs for
# torchvision's ImageNet-pretrained weights.
_IMAGENET_MEAN = (0.485, 0.456, 0.406)
_IMAGENET_STD = (0.229, 0.224, 0.225)

# Preprocessing pipeline: resize a PIL image to 256 x 256, convert it to a
# float tensor in [0, 1] (ToTensor), then normalize with the statistics above.
# NOTE(review): not referenced by the visible UI code — confirm it is still needed.
transform = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
    ]
)


###################################### Load model ######################################
@st.cache_resource
def load_model():
    """Create the pretrained DeepLabV3-ResNet101 segmentation network in eval mode.

    ``st.cache_resource`` makes Streamlit build the model (and fetch its
    weights) once and reuse the same object across reruns and sessions.

    Returns:
        The torchvision DeepLabV3-ResNet101 module, switched to evaluation mode.
    """
    # NOTE(review): ``pretrained=True`` is deprecated since torchvision 0.13 in
    # favour of ``weights=...``; kept here for compatibility with older releases.
    network = torchvision.models.segmentation.deeplabv3_resnet101(pretrained=True)
    # Module.eval() disables training-only behaviour (e.g. batch-norm statistic
    # updates, dropout) and returns the module itself.
    return network.eval()


model = load_model()


def _to_model_input(rgb_image):
    """Convert an (H, W, 3) uint8 RGB array into a normalized (1, 3, H, W) float tensor.

    Pixels are scaled to [0, 1] before the ImageNet mean/std normalization;
    the mean/std values are defined for [0, 1] inputs, so normalizing raw
    0-255 values would feed the pretrained weights wildly out-of-range data.
    """
    tensor = torch.from_numpy(rgb_image).permute(2, 0, 1).float() / 255.0
    return torchvision.transforms.functional.normalize(
        tensor,
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ).unsqueeze(0)


def _class_palette(num_classes):
    """Return a deterministic (num_classes, 3) uint8 RGB colour table, one row per class id.

    Same recipe as the PyTorch Hub DeepLabV3 example: multiply each class id by
    fixed large constants and wrap modulo 255 to spread colours apart.
    """
    multipliers = torch.tensor([2**25 - 1, 2**15 - 1, 2**21 - 1])
    colors = torch.arange(num_classes)[:, None] * multipliers
    return (colors % 255).numpy().astype(np.uint8)


########################################## UI ##########################################
st.title("Colorization")

uploaded_file = st.file_uploader("Upload grayscale image or video", type=image_extensions + video_extensions)
if uploaded_file:
    file_type = check_type(file_name=uploaded_file.name)
    # Image
    if file_type == "image":
        # convert("RGB") turns grayscale ("L"), palette ("P") and RGBA uploads into
        # exactly three channels, as the model and the 3-value mean/std require.
        # Keeping the array uint8 lets st.image and image_comparison display it
        # directly (st.image rejects float data outside [0, 1]).
        image = np.array(Image.open(uploaded_file).convert("RGB"))
        input_tensor = _to_model_input(image)
        process_button = st.button("Process")
        if process_button:
            with st.spinner("Từ từ coi..."):
                # Inference only: no autograd graph is needed.
                with torch.no_grad():
                    logits = model(input_tensor)["out"][0]  # (num_classes, H, W)
                # Raw logits have one channel per class, which cannot be shown as
                # an image; render each pixel's most likely class as a colour.
                class_map = logits.argmax(dim=0).numpy()
                segment = _class_palette(logits.shape[0])[class_map]

                st.image(segment)
                st.image(image)

                image_comparison(
                    img1=image,
                    img2=segment,
                    label1="Grayscale",
                    label2="Colorized",
                    make_responsive=True,
                    show_labels=True,
                )
    # Video
    elif file_type == "video":
        # TODO: video colorization is not implemented yet; show the uploaded clip as-is.
        st.video(uploaded_file)
    else:
        st.error(f"Unsupported file type: {uploaded_file.name}")

# CSS that hides the element with id "MainMenu" (Streamlit's top-right menu)
# and the page footer. Built piecewise; the value is identical to a triple-quoted
# block indented by eight spaces.
hide_menu_style = (
    "\n"
    "        <style>\n"
    "        #MainMenu {visibility: hidden; }\n"
    "        footer {visibility: hidden;}\n"
    "        </style>\n"
    "        "
)
# unsafe_allow_html is needed, otherwise st.markdown escapes the <style> tag.
st.markdown(body=hide_menu_style, unsafe_allow_html=True)