Spaces · Runtime error
Commit · a5291ee
Parent(s): Duplicate from krystaltechnology/image-video-colorization
Browse files
- .gitattributes +34 -0
- .streamlit/config.toml +8 -0
- 01_B&W_Videos_Colorizer.py +130 -0
- README.md +13 -0
- models/deep_colorization/colorizers/__init__.py +6 -0
- models/deep_colorization/colorizers/__pycache__/__init__.cpython-310.pyc +0 -0
- models/deep_colorization/colorizers/__pycache__/__init__.cpython-37.pyc +0 -0
- models/deep_colorization/colorizers/__pycache__/base_color.cpython-310.pyc +0 -0
- models/deep_colorization/colorizers/__pycache__/base_color.cpython-37.pyc +0 -0
- models/deep_colorization/colorizers/__pycache__/eccv16.cpython-310.pyc +0 -0
- models/deep_colorization/colorizers/__pycache__/eccv16.cpython-37.pyc +0 -0
- models/deep_colorization/colorizers/__pycache__/siggraph17.cpython-310.pyc +0 -0
- models/deep_colorization/colorizers/__pycache__/siggraph17.cpython-37.pyc +0 -0
- models/deep_colorization/colorizers/__pycache__/util.cpython-310.pyc +0 -0
- models/deep_colorization/colorizers/__pycache__/util.cpython-37.pyc +0 -0
- models/deep_colorization/colorizers/base_color.py +24 -0
- models/deep_colorization/colorizers/eccv16.py +105 -0
- models/deep_colorization/colorizers/siggraph17.py +168 -0
- models/deep_colorization/colorizers/util.py +47 -0
- pages/02_Input_Youtube_Link.py +129 -0
- pages/03_B&W_Images_Colorizer.py +98 -0
- pages/04_Super_Resolution.py +254 -0
- pages/05_Image_Denoizer.py +250 -0
- requirements.txt +25 -0
- utils.py +67 -0
.gitattributes
ADDED
@@ -0,0 +1,34 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
.streamlit/config.toml
ADDED
@@ -0,0 +1,8 @@
[theme]
primaryColor="#ff6328"
backgroundColor="#FFFFFF"
secondaryBackgroundColor="#F0F2F6"
textColor="#262730"
font="sans serif"
[server]
maxUploadSize=1028
01_B&W_Videos_Colorizer.py
ADDED
@@ -0,0 +1,130 @@
import os
import tempfile
import time

os.environ["IMAGEIO_FFMPEG_EXE"] = "/usr/bin/ffmpeg"

import cv2
import moviepy.editor as mp
import numpy as np
import streamlit as st
from streamlit_lottie import st_lottie
from tqdm import tqdm

from models.deep_colorization.colorizers import eccv16
from utils import load_lottieurl, format_time, colorize_frame, change_model

st.title("B&W Videos Colorizer")

st.write("""
##### Upload a black and white video and get a colorized version of it.
###### ➠ This space is using CPU Basic so it might take a while to colorize a video.""")

# st.set_page_config(page_title="Image & Video Colorizer", page_icon="🎨", layout="wide")

loaded_model = eccv16(pretrained=True).eval()
current_model = "None"


def main():
    model = st.selectbox(
        "Select Model (Both models have their pros and cons, I recommend trying both and keeping the best for your task)",
        ["ECCV16", "SIGGRAPH17"], index=0)

    loaded_model = change_model(current_model, model)
    st.write(f"Model is now {model}")

    uploaded_file = st.file_uploader("Upload your video here...", type=['mp4', 'mov', 'avi', 'mkv'])

    if st.button("Colorize"):
        if uploaded_file is not None:
            file_extension = os.path.splitext(uploaded_file.name)[1].lower()
            if file_extension in ['.mp4', '.avi', '.mov', '.mkv']:
                # Save the video file to a temporary location
                temp_file = tempfile.NamedTemporaryFile(delete=False)
                temp_file.write(uploaded_file.read())

                audio = mp.AudioFileClip(temp_file.name)

                # Open the video using cv2.VideoCapture
                video = cv2.VideoCapture(temp_file.name)

                # Get video information
                fps = video.get(cv2.CAP_PROP_FPS)

                col1, col2 = st.columns([0.5, 0.5])
                with col1:
                    st.markdown('<p style="text-align: center;">Before</p>', unsafe_allow_html=True)
                    st.video(temp_file.name)

                with col2:
                    st.markdown('<p style="text-align: center;">After</p>', unsafe_allow_html=True)

                    with st.spinner("Colorizing frames..."):
                        # Colorize video frames and store in a list
                        output_frames = []
                        total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
                        progress_bar = st.progress(0)  # Create a progress bar

                        start_time = time.time()
                        time_text = st.text("Time Remaining: ")  # Initialize text value

                        for _ in tqdm(range(total_frames), unit='frame', desc="Progress"):
                            ret, frame = video.read()
                            if not ret:
                                break

                            colorized_frame = colorize_frame(frame, loaded_model)
                            output_frames.append((colorized_frame * 255).astype(np.uint8))

                            elapsed_time = time.time() - start_time
                            frames_completed = len(output_frames)
                            frames_remaining = total_frames - frames_completed
                            time_remaining = (frames_remaining / frames_completed) * elapsed_time

                            progress_bar.progress(frames_completed / total_frames)  # Update progress bar

                            if frames_completed < total_frames:
                                time_text.text(f"Time Remaining: {format_time(time_remaining)}")  # Update text value
                            else:
                                time_text.empty()  # Remove text value
                                progress_bar.empty()

                    with st.spinner("Merging frames to video..."):
                        frame_size = output_frames[0].shape[:2]
                        output_filename = "output.mp4"
                        fourcc = cv2.VideoWriter_fourcc(*"mp4v")  # Codec for MP4 video
                        out = cv2.VideoWriter(output_filename, fourcc, fps, (frame_size[1], frame_size[0]))

                        # Write the colorized frames (RGB -> BGR for OpenCV)
                        for frame in output_frames:
                            frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
                            out.write(frame_bgr)

                        out.release()

                        # Convert the output video to a format compatible with Streamlit
                        converted_filename = "converted_output.mp4"
                        clip = mp.VideoFileClip(output_filename)
                        clip = clip.set_audio(audio)

                        clip.write_videofile(converted_filename, codec="libx264")

                        # Display the converted video using st.video()
                        st.video(converted_filename)
                        st.balloons()

                        # Add a download button for the colorized video
                        st.download_button(
                            label="Download Colorized Video",
                            data=open(converted_filename, "rb").read(),
                            file_name="colorized_video.mp4"
                        )

                        # Close and delete the temporary file after processing
                        video.release()
                        temp_file.close()


if __name__ == "__main__":
    main()
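For reference, the per-frame loop above can be exercised outside Streamlit. A minimal sketch, assuming this commit's utils and models packages are importable; "input.mp4" and the 60-frame cap are placeholders:

import cv2
import numpy as np

from models.deep_colorization.colorizers import eccv16
from utils import colorize_frame

# Placeholder input path and frame cap; colorize_frame returns a float
# RGB array in [0, 1] at the frame's original resolution.
model = eccv16(pretrained=True).eval()
video = cv2.VideoCapture("input.mp4")

frames = []
while len(frames) < 60:
    ret, frame = video.read()
    if not ret:
        break
    frames.append((colorize_frame(frame, model) * 255).astype(np.uint8))

video.release()
print(f"Colorized {len(frames)} frames")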
README.md
ADDED
@@ -0,0 +1,13 @@
---
title: Image Video Colorization
emoji: 🎥
colorFrom: blue
colorTo: purple
sdk: streamlit
sdk_version: 1.21.0
app_file: 01_B&W_Videos_Colorizer.py
pinned: false
duplicated_from: krystaltechnology/image-video-colorization
---

Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
models/deep_colorization/colorizers/__init__.py
ADDED
@@ -0,0 +1,6 @@
from .base_color import *
from .eccv16 import *
from .siggraph17 import *
from .util import *
models/deep_colorization/colorizers/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (279 Bytes)
models/deep_colorization/colorizers/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (285 Bytes)
models/deep_colorization/colorizers/__pycache__/base_color.cpython-310.pyc
ADDED
Binary file (1.24 kB)
models/deep_colorization/colorizers/__pycache__/base_color.cpython-37.pyc
ADDED
Binary file (1.24 kB)
models/deep_colorization/colorizers/__pycache__/eccv16.cpython-310.pyc
ADDED
Binary file (3.27 kB)
models/deep_colorization/colorizers/__pycache__/eccv16.cpython-37.pyc
ADDED
Binary file (3.26 kB)
models/deep_colorization/colorizers/__pycache__/siggraph17.cpython-310.pyc
ADDED
Binary file (4.36 kB)
models/deep_colorization/colorizers/__pycache__/siggraph17.cpython-37.pyc
ADDED
Binary file (4.36 kB)
models/deep_colorization/colorizers/__pycache__/util.cpython-310.pyc
ADDED
Binary file (1.74 kB)
models/deep_colorization/colorizers/__pycache__/util.cpython-37.pyc
ADDED
Binary file (1.71 kB)
models/deep_colorization/colorizers/base_color.py
ADDED
@@ -0,0 +1,24 @@
import torch
from torch import nn


class BaseColor(nn.Module):
    def __init__(self):
        super(BaseColor, self).__init__()

        # Lab color space constants: L channel center/scale and ab scale
        self.l_cent = 50.
        self.l_norm = 100.
        self.ab_norm = 110.

    def normalize_l(self, in_l):
        return (in_l - self.l_cent) / self.l_norm

    def unnormalize_l(self, in_l):
        return in_l * self.l_norm + self.l_cent

    def normalize_ab(self, in_ab):
        return in_ab / self.ab_norm

    def unnormalize_ab(self, in_ab):
        return in_ab * self.ab_norm
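The constants above map CIE Lab values into the ranges the networks train on: L (0 to 100) is centered and scaled to roughly [-0.5, 0.5], and ab (roughly -110 to 110) to about [-1, 1]. A quick worked example of the round trip:

import torch

from models.deep_colorization.colorizers import BaseColor

bc = BaseColor()
L = torch.tensor([0., 50., 100.])
ab = torch.tensor([-110., 0., 110.])

print(bc.normalize_l(L))    # tensor([-0.5000, 0.0000, 0.5000])
print(bc.normalize_ab(ab))  # tensor([-1., 0., 1.])
print(bc.unnormalize_l(bc.normalize_l(L)))  # back to tensor([0., 50., 100.])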
models/deep_colorization/colorizers/eccv16.py
ADDED
@@ -0,0 +1,105 @@
import torch
import torch.nn as nn
import numpy as np
from IPython import embed

from .base_color import *


class ECCVGenerator(BaseColor):
    def __init__(self, norm_layer=nn.BatchNorm2d):
        super(ECCVGenerator, self).__init__()

        model1 = [nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1, bias=True)]
        model1 += [nn.ReLU(True)]
        model1 += [nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=True)]
        model1 += [nn.ReLU(True)]
        model1 += [norm_layer(64)]

        model2 = [nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=True)]
        model2 += [nn.ReLU(True)]
        model2 += [nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1, bias=True)]
        model2 += [nn.ReLU(True)]
        model2 += [norm_layer(128)]

        model3 = [nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1, bias=True)]
        model3 += [nn.ReLU(True)]
        model3 += [nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True)]
        model3 += [nn.ReLU(True)]
        model3 += [nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1, bias=True)]
        model3 += [nn.ReLU(True)]
        model3 += [norm_layer(256)]

        model4 = [nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1, bias=True)]
        model4 += [nn.ReLU(True)]
        model4 += [nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True)]
        model4 += [nn.ReLU(True)]
        model4 += [nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True)]
        model4 += [nn.ReLU(True)]
        model4 += [norm_layer(512)]

        model5 = [nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True)]
        model5 += [nn.ReLU(True)]
        model5 += [nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True)]
        model5 += [nn.ReLU(True)]
        model5 += [nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True)]
        model5 += [nn.ReLU(True)]
        model5 += [norm_layer(512)]

        model6 = [nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True)]
        model6 += [nn.ReLU(True)]
        model6 += [nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True)]
        model6 += [nn.ReLU(True)]
        model6 += [nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True)]
        model6 += [nn.ReLU(True)]
        model6 += [norm_layer(512)]

        model7 = [nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True)]
        model7 += [nn.ReLU(True)]
        model7 += [nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True)]
        model7 += [nn.ReLU(True)]
        model7 += [nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True)]
        model7 += [nn.ReLU(True)]
        model7 += [norm_layer(512)]

        model8 = [nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=True)]
        model8 += [nn.ReLU(True)]
        model8 += [nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True)]
        model8 += [nn.ReLU(True)]
        model8 += [nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True)]
        model8 += [nn.ReLU(True)]

        # 313 quantized ab bins, as in the ECCV 2016 colorization paper
        model8 += [nn.Conv2d(256, 313, kernel_size=1, stride=1, padding=0, bias=True)]

        self.model1 = nn.Sequential(*model1)
        self.model2 = nn.Sequential(*model2)
        self.model3 = nn.Sequential(*model3)
        self.model4 = nn.Sequential(*model4)
        self.model5 = nn.Sequential(*model5)
        self.model6 = nn.Sequential(*model6)
        self.model7 = nn.Sequential(*model7)
        self.model8 = nn.Sequential(*model8)

        self.softmax = nn.Softmax(dim=1)
        self.model_out = nn.Conv2d(313, 2, kernel_size=1, padding=0, dilation=1, stride=1, bias=False)
        self.upsample4 = nn.Upsample(scale_factor=4, mode='bilinear')

    def forward(self, input_l):
        conv1_2 = self.model1(self.normalize_l(input_l))
        conv2_2 = self.model2(conv1_2)
        conv3_3 = self.model3(conv2_2)
        conv4_3 = self.model4(conv3_3)
        conv5_3 = self.model5(conv4_3)
        conv6_3 = self.model6(conv5_3)
        conv7_3 = self.model7(conv6_3)
        conv8_3 = self.model8(conv7_3)
        out_reg = self.model_out(self.softmax(conv8_3))

        return self.unnormalize_ab(self.upsample4(out_reg))


def eccv16(pretrained=True):
    model = ECCVGenerator()
    if pretrained:
        import torch.utils.model_zoo as model_zoo
        model.load_state_dict(model_zoo.load_url(
            'https://colorizers.s3.us-east-2.amazonaws.com/colorization_release_v2-9b330a0b.pth',
            map_location='cpu', check_hash=True))
    return model
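To see the generator in context: a minimal sketch of colorizing one still image with the helpers from util.py (added later in this commit); "photo.jpg" is a placeholder path:

import numpy as np
import torch
from PIL import Image

from models.deep_colorization.colorizers import eccv16, load_img, postprocess_tens, preprocess_img

model = eccv16(pretrained=True).eval()

img = load_img("photo.jpg")  # placeholder path; returns an HxWx3 uint8 array
tens_l_orig, tens_l_rs = preprocess_img(img, HW=(256, 256))

# The network predicts ab from the resized L channel; postprocess_tens
# upsamples ab to the original resolution and converts Lab -> RGB in [0, 1].
with torch.no_grad():
    out_img = postprocess_tens(tens_l_orig, model(tens_l_rs).cpu())

Image.fromarray((out_img * 255).astype(np.uint8)).save("photo_colorized.jpg")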
models/deep_colorization/colorizers/siggraph17.py
ADDED
@@ -0,0 +1,168 @@
import torch
import torch.nn as nn

from .base_color import *


class SIGGRAPHGenerator(BaseColor):
    def __init__(self, norm_layer=nn.BatchNorm2d, classes=529):
        super(SIGGRAPHGenerator, self).__init__()

        # Conv1
        model1 = [nn.Conv2d(4, 64, kernel_size=3, stride=1, padding=1, bias=True)]
        model1 += [nn.ReLU(True)]
        model1 += [nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=True)]
        model1 += [nn.ReLU(True)]
        model1 += [norm_layer(64)]
        # add a subsampling operation

        # Conv2
        model2 = [nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=True)]
        model2 += [nn.ReLU(True)]
        model2 += [nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=True)]
        model2 += [nn.ReLU(True)]
        model2 += [norm_layer(128)]
        # add a subsampling layer operation

        # Conv3
        model3 = [nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1, bias=True)]
        model3 += [nn.ReLU(True)]
        model3 += [nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True)]
        model3 += [nn.ReLU(True)]
        model3 += [nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True)]
        model3 += [nn.ReLU(True)]
        model3 += [norm_layer(256)]
        # add a subsampling layer operation

        # Conv4
        model4 = [nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1, bias=True)]
        model4 += [nn.ReLU(True)]
        model4 += [nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True)]
        model4 += [nn.ReLU(True)]
        model4 += [nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True)]
        model4 += [nn.ReLU(True)]
        model4 += [norm_layer(512)]

        # Conv5
        model5 = [nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True)]
        model5 += [nn.ReLU(True)]
        model5 += [nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True)]
        model5 += [nn.ReLU(True)]
        model5 += [nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True)]
        model5 += [nn.ReLU(True)]
        model5 += [norm_layer(512)]

        # Conv6
        model6 = [nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True)]
        model6 += [nn.ReLU(True)]
        model6 += [nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True)]
        model6 += [nn.ReLU(True)]
        model6 += [nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True)]
        model6 += [nn.ReLU(True)]
        model6 += [norm_layer(512)]

        # Conv7
        model7 = [nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True)]
        model7 += [nn.ReLU(True)]
        model7 += [nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True)]
        model7 += [nn.ReLU(True)]
        model7 += [nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True)]
        model7 += [nn.ReLU(True)]
        model7 += [norm_layer(512)]

        # Conv8
        model8up = [nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=True)]
        model3short8 = [nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True)]

        model8 = [nn.ReLU(True)]
        model8 += [nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True)]
        model8 += [nn.ReLU(True)]
        model8 += [nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True)]
        model8 += [nn.ReLU(True)]
        model8 += [norm_layer(256)]

        # Conv9
        model9up = [nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=True)]
        model2short9 = [nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=True)]
        # add the two feature maps above

        model9 = [nn.ReLU(True)]
        model9 += [nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=True)]
        model9 += [nn.ReLU(True)]
        model9 += [norm_layer(128)]

        # Conv10
        model10up = [nn.ConvTranspose2d(128, 128, kernel_size=4, stride=2, padding=1, bias=True)]
        model1short10 = [nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=True)]
        # add the two feature maps above

        model10 = [nn.ReLU(True)]
        model10 += [nn.Conv2d(128, 128, kernel_size=3, dilation=1, stride=1, padding=1, bias=True)]
        model10 += [nn.LeakyReLU(negative_slope=.2)]

        # classification output
        model_class = [nn.Conv2d(256, classes, kernel_size=1, padding=0, dilation=1, stride=1, bias=True)]

        # regression output
        model_out = [nn.Conv2d(128, 2, kernel_size=1, padding=0, dilation=1, stride=1, bias=True)]
        model_out += [nn.Tanh()]

        self.model1 = nn.Sequential(*model1)
        self.model2 = nn.Sequential(*model2)
        self.model3 = nn.Sequential(*model3)
        self.model4 = nn.Sequential(*model4)
        self.model5 = nn.Sequential(*model5)
        self.model6 = nn.Sequential(*model6)
        self.model7 = nn.Sequential(*model7)
        self.model8up = nn.Sequential(*model8up)
        self.model8 = nn.Sequential(*model8)
        self.model9up = nn.Sequential(*model9up)
        self.model9 = nn.Sequential(*model9)
        self.model10up = nn.Sequential(*model10up)
        self.model10 = nn.Sequential(*model10)
        self.model3short8 = nn.Sequential(*model3short8)
        self.model2short9 = nn.Sequential(*model2short9)
        self.model1short10 = nn.Sequential(*model1short10)

        self.model_class = nn.Sequential(*model_class)
        self.model_out = nn.Sequential(*model_out)

        self.upsample4 = nn.Sequential(*[nn.Upsample(scale_factor=4, mode='bilinear')])
        self.softmax = nn.Sequential(*[nn.Softmax(dim=1)])

    def forward(self, input_A, input_B=None, mask_B=None):
        # Default to "no hints": zero ab channels and an all-zero hint mask
        if input_B is None:
            input_B = torch.cat((input_A * 0, input_A * 0), dim=1)
        if mask_B is None:
            mask_B = input_A * 0

        conv1_2 = self.model1(torch.cat((self.normalize_l(input_A), self.normalize_ab(input_B), mask_B), dim=1))
        conv2_2 = self.model2(conv1_2[:, :, ::2, ::2])
        conv3_3 = self.model3(conv2_2[:, :, ::2, ::2])
        conv4_3 = self.model4(conv3_3[:, :, ::2, ::2])
        conv5_3 = self.model5(conv4_3)
        conv6_3 = self.model6(conv5_3)
        conv7_3 = self.model7(conv6_3)

        conv8_up = self.model8up(conv7_3) + self.model3short8(conv3_3)
        conv8_3 = self.model8(conv8_up)
        conv9_up = self.model9up(conv8_3) + self.model2short9(conv2_2)
        conv9_3 = self.model9(conv9_up)
        conv10_up = self.model10up(conv9_3) + self.model1short10(conv1_2)
        conv10_2 = self.model10(conv10_up)
        out_reg = self.model_out(conv10_2)

        return self.unnormalize_ab(out_reg)


def siggraph17(pretrained=True):
    model = SIGGRAPHGenerator()
    if pretrained:
        import torch.utils.model_zoo as model_zoo
        model.load_state_dict(model_zoo.load_url(
            'https://colorizers.s3.us-east-2.amazonaws.com/siggraph17-df00044c.pth',
            map_location='cpu', check_hash=True))
    return model
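Unlike ECCV16, this generator accepts optional user hints: input_B carries ab color values and mask_B marks where they apply. The app never passes hints, but here is a hedged sketch of that interface (the shapes, patch location, and hint values are all illustrative):

import torch

from models.deep_colorization.colorizers import siggraph17

model = siggraph17(pretrained=True).eval()

L = torch.zeros(1, 1, 256, 256)      # stand-in L channel (Lab range 0-100)
hints = torch.zeros(1, 2, 256, 256)  # ab hints (Lab range roughly -110..110)
mask = torch.zeros(1, 1, 256, 256)   # 1 where a hint is given

# Push one 10x10 patch toward an arbitrary color
hints[:, :, 100:110, 100:110] = torch.tensor([40., -20.]).view(1, 2, 1, 1)
mask[:, :, 100:110, 100:110] = 1.0

with torch.no_grad():
    ab = model(L, hints, mask)  # 1 x 2 x 256 x 256 ab prediction
print(ab.shape)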
models/deep_colorization/colorizers/util.py
ADDED
@@ -0,0 +1,47 @@
from PIL import Image
import numpy as np
from skimage import color
import torch
import torch.nn.functional as F
from IPython import embed


def load_img(img_path):
    out_np = np.asarray(Image.open(img_path))
    if out_np.ndim == 2:  # grayscale: replicate to 3 channels
        out_np = np.tile(out_np[:, :, None], 3)
    return out_np


def resize_img(img, HW=(256, 256), resample=3):
    return np.asarray(Image.fromarray(img).resize((HW[1], HW[0]), resample=resample))


def preprocess_img(img_rgb_orig, HW=(256, 256), resample=3):
    # return original size L and resized L as torch Tensors
    img_rgb_rs = resize_img(img_rgb_orig, HW=HW, resample=resample)

    img_lab_orig = color.rgb2lab(img_rgb_orig)
    img_lab_rs = color.rgb2lab(img_rgb_rs)

    img_l_orig = img_lab_orig[:, :, 0]
    img_l_rs = img_lab_rs[:, :, 0]

    tens_orig_l = torch.Tensor(img_l_orig)[None, None, :, :]
    tens_rs_l = torch.Tensor(img_l_rs)[None, None, :, :]

    return (tens_orig_l, tens_rs_l)


def postprocess_tens(tens_orig_l, out_ab, mode='bilinear'):
    # tens_orig_l  1 x 1 x H_orig x W_orig
    # out_ab       1 x 2 x H x W

    HW_orig = tens_orig_l.shape[2:]
    HW = out_ab.shape[2:]

    # call resize function if needed
    if HW_orig[0] != HW[0] or HW_orig[1] != HW[1]:
        out_ab_orig = F.interpolate(out_ab, size=HW_orig, mode='bilinear')
    else:
        out_ab_orig = out_ab

    out_lab_orig = torch.cat((tens_orig_l, out_ab_orig), dim=1)
    return color.lab2rgb(out_lab_orig.data.cpu().numpy()[0, ...].transpose((1, 2, 0)))
pages/02_Input_Youtube_Link.py
ADDED
@@ -0,0 +1,129 @@
import time

import cv2
import moviepy.editor as mp
import numpy as np
import streamlit as st
from pytube import YouTube
from streamlit_lottie import st_lottie
from tqdm import tqdm

from models.deep_colorization.colorizers import eccv16
from utils import colorize_frame, format_time
from utils import load_lottieurl, change_model

# st.set_page_config(page_title="Image & Video Colorizer", page_icon="🎨", layout="wide")


loaded_model = eccv16(pretrained=True).eval()
current_model = "None"

st.title("Image & Video Colorizer")

st.write("""
##### Input a YouTube black and white video link and get a colorized version of it.
###### ➠ This space is using CPU Basic so it might take a while to colorize a video.""")


@st.cache_data()
def download_video(link):
    yt = YouTube(link)
    video = yt.streams.filter(progressive=True, file_extension='mp4').order_by(
        'resolution').desc().first().download(filename="video.mp4")
    return video


def main():
    model = st.selectbox(
        "Select Model (Both models have their pros and cons, I recommend trying both and keeping the best for your task)",
        ["ECCV16", "SIGGRAPH17"], index=0)

    loaded_model = change_model(current_model, model)
    st.write(f"Model is now {model}")

    link = st.text_input("YouTube Link (The longer the video, the longer the processing time)")
    if st.button("Colorize"):
        if link != "":
            print(link)
            yt_video = download_video(link)
            print(yt_video)
            col1, col2 = st.columns([0.5, 0.5])
            with col1:
                st.markdown('<p style="text-align: center;">Before</p>', unsafe_allow_html=True)
                st.video(yt_video)
            with col2:
                st.markdown('<p style="text-align: center;">After</p>', unsafe_allow_html=True)
                with st.spinner("Colorizing frames..."):
                    # Colorize video frames and store in a list
                    output_frames = []

                    audio = mp.AudioFileClip("video.mp4")
                    video = cv2.VideoCapture("video.mp4")

                    total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
                    fps = video.get(cv2.CAP_PROP_FPS)

                    progress_bar = st.progress(0)  # Create a progress bar
                    start_time = time.time()
                    time_text = st.text("Time Remaining: ")  # Initialize text value

                    for _ in tqdm(range(total_frames), unit='frame', desc="Progress"):
                        ret, frame = video.read()
                        if not ret:
                            break

                        colorized_frame = colorize_frame(frame, loaded_model)
                        output_frames.append((colorized_frame * 255).astype(np.uint8))

                        elapsed_time = time.time() - start_time
                        frames_completed = len(output_frames)
                        frames_remaining = total_frames - frames_completed
                        time_remaining = (frames_remaining / frames_completed) * elapsed_time

                        progress_bar.progress(frames_completed / total_frames)  # Update progress bar

                        if frames_completed < total_frames:
                            time_text.text(f"Time Remaining: {format_time(time_remaining)}")  # Update text value
                        else:
                            time_text.empty()  # Remove text value
                            progress_bar.empty()

                with st.spinner("Merging frames to video..."):
                    frame_size = output_frames[0].shape[:2]
                    output_filename = "output.mp4"
                    fourcc = cv2.VideoWriter_fourcc(*"mp4v")  # Codec for MP4 video
                    out = cv2.VideoWriter(output_filename, fourcc, fps, (frame_size[1], frame_size[0]))

                    # Write the colorized frames (RGB -> BGR for OpenCV)
                    for frame in output_frames:
                        frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
                        out.write(frame_bgr)

                    out.release()

                    # Convert the output video to a format compatible with Streamlit
                    converted_filename = "converted_output.mp4"
                    clip = mp.VideoFileClip(output_filename)
                    clip = clip.set_audio(audio)

                    clip.write_videofile(converted_filename, codec="libx264")

                    # Display the converted video using st.video()
                    st.video(converted_filename)
                    st.balloons()

                    # Add a download button for the colorized video
                    st.download_button(
                        label="Download Colorized Video",
                        data=open(converted_filename, "rb").read(),
                        file_name="colorized_video.mp4"
                    )

                    # Close the video capture after processing
                    video.release()
        else:
            st.warning('Please Type a link', icon="⚠️")


if __name__ == "__main__":
    main()
pages/03_B&W_Images_Colorizer.py
ADDED
@@ -0,0 +1,98 @@
import os
import zipfile

import streamlit as st
from PIL import Image
from streamlit_lottie import st_lottie

from models.deep_colorization.colorizers import eccv16
from utils import colorize_image, change_model, load_lottieurl

# st.set_page_config(page_title="Image & Video Colorizer", page_icon="🎨", layout="wide")

st.title("B&W Images Colorizer")


loaded_model = eccv16(pretrained=True).eval()
current_model = "None"

st.write("""
##### Input a black and white image and get a colorized version of it.
###### ➠ If you want to colorize multiple images just upload them all at once.
###### ➠ Uploading already colored images won't raise errors but images won't look good.""")


def main():
    model = st.selectbox(
        "Select Model (Both models have their pros and cons, I recommend trying both and keeping the best for your task)",
        ["ECCV16", "SIGGRAPH17"], index=0)

    # Make the user select a model
    loaded_model = change_model(current_model, model)
    st.write(f"Model is now {model}")

    # Ask the user if they want to see the colorization in real time
    display_results = st.checkbox('Display results in real time', value=True)

    # Input for the user to upload images
    uploaded_file = st.file_uploader("Upload your images here...", type=['jpg', 'png', 'jpeg'],
                                     accept_multiple_files=True)

    # If the user clicks on the button
    if st.button("Colorize"):
        # If the user uploaded images
        if uploaded_file:
            if display_results:
                col1, col2 = st.columns([0.5, 0.5])
                with col1:
                    st.markdown('<p style="text-align: center;">Before</p>', unsafe_allow_html=True)
                with col2:
                    st.markdown('<p style="text-align: center;">After</p>', unsafe_allow_html=True)
            else:
                col1, col2, col3 = st.columns(3)

            for i, file in enumerate(uploaded_file):
                file_extension = os.path.splitext(file.name)[1].lower()
                if file_extension in ['.jpg', '.png', '.jpeg']:
                    image = Image.open(file)
                    if display_results:
                        with col1:
                            st.image(image, use_column_width="always")
                        with col2:
                            with st.spinner("Colorizing image..."):
                                out_img, new_img = colorize_image(file, loaded_model)
                                new_img.save("IMG_" + str(i + 1) + ".jpg")
                                st.image(out_img, use_column_width="always")
                    else:
                        out_img, new_img = colorize_image(file, loaded_model)
                        new_img.save("IMG_" + str(i + 1) + ".jpg")

            if len(uploaded_file) > 1:
                # Create a zip file
                zip_filename = "colorized_images.zip"
                with zipfile.ZipFile(zip_filename, "w") as zip_file:
                    # Add colorized images to the zip file (archive names match the saved files)
                    for i in range(len(uploaded_file)):
                        zip_file.write("IMG_" + str(i + 1) + ".jpg", "IMG_" + str(i + 1) + ".jpg")
                with col2:
                    # Provide the zip file data for download
                    st.download_button(
                        label="Download Colorized Images",
                        data=open(zip_filename, "rb").read(),
                        file_name=zip_filename,
                    )
            else:
                with col2:
                    st.download_button(
                        label="Download Colorized Image",
                        data=open("IMG_1.jpg", "rb").read(),
                        file_name="IMG_1.jpg",
                    )
        else:
            st.warning('Upload a file', icon="⚠️")


if __name__ == "__main__":
    main()
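The page above writes each IMG_*.jpg and the archive into the working directory, which can collide when several sessions run at once. A hedged alternative sketch that assembles the download in memory instead (images_to_zip_bytes is a hypothetical helper, not part of this commit):

import io
import zipfile


def images_to_zip_bytes(images):
    # Pack PIL images into an in-memory zip; no files touch the disk.
    buffer = io.BytesIO()
    with zipfile.ZipFile(buffer, "w") as zf:
        for i, img in enumerate(images, start=1):
            member = io.BytesIO()
            img.save(member, format="JPEG")
            zf.writestr(f"IMG_{i}.jpg", member.getvalue())
    return buffer.getvalue()

# Usage with st.download_button:
# st.download_button(label="Download Colorized Images",
#                    data=images_to_zip_bytes(colorized_pils),
#                    file_name="colorized_images.zip")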
pages/04_Super_Resolution.py
ADDED
@@ -0,0 +1,254 @@
import streamlit as st
import cv2
import numpy
import os
import random
from basicsr.archs.rrdbnet_arch import RRDBNet
from basicsr.utils.download_util import load_file_from_url
from PIL import Image

from realesrgan import RealESRGANer
from realesrgan.archs.srvgg_arch import SRVGGNetCompact


last_file = None
img_mode = "RGBA"


def realesrgan(img, model_name, denoise_strength, face_enhance, outscale):
    """Real-ESRGAN function to restore (and upscale) images."""
    if not img:
        return

    # Define model parameters
    if model_name == 'RealESRGAN_x4plus':  # x4 RRDBNet model
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
        netscale = 4
        file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth']
    elif model_name == 'RealESRNet_x4plus':  # x4 RRDBNet model
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
        netscale = 4
        file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth']
    elif model_name == 'RealESRGAN_x4plus_anime_6B':  # x4 RRDBNet model with 6 blocks
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
        netscale = 4
        file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth']
    elif model_name == 'RealESRGAN_x2plus':  # x2 RRDBNet model
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
        netscale = 2
        file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth']
    elif model_name == 'realesr-general-x4v3':  # x4 VGG-style model (S size)
        model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
        netscale = 4
        file_url = [
            'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth',
            'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth'
        ]

    # Determine model paths
    model_path = os.path.join('weights', model_name + '.pth')
    if not os.path.isfile(model_path):
        ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
        for url in file_url:
            # model_path will be updated
            model_path = load_file_from_url(
                url=url, model_dir=os.path.join(ROOT_DIR, 'weights'), progress=True, file_name=None)

    # Use dni to control the denoise strength
    dni_weight = None
    if model_name == 'realesr-general-x4v3' and denoise_strength != 1:
        wdn_model_path = model_path.replace('realesr-general-x4v3', 'realesr-general-wdn-x4v3')
        model_path = [model_path, wdn_model_path]
        dni_weight = [denoise_strength, 1 - denoise_strength]

    # Restorer Class
    upsampler = RealESRGANer(
        scale=netscale,
        model_path=model_path,
        dni_weight=dni_weight,
        model=model,
        tile=0,
        tile_pad=10,
        pre_pad=10,
        half=False,
        gpu_id=None
    )

    # Use GFPGAN for face enhancement
    if face_enhance:
        from gfpgan import GFPGANer
        face_enhancer = GFPGANer(
            model_path='https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth',
            upscale=outscale,
            arch='clean',
            channel_multiplier=2,
            bg_upsampler=upsampler)

    # Convert the input PIL image to a cv2 image so realesrgan can process it
    # (force RGBA so the color conversion below always sees 4 channels)
    cv_img = numpy.array(img.convert("RGBA"))
    img = cv2.cvtColor(cv_img, cv2.COLOR_RGBA2BGRA)

    # Apply restoration
    try:
        if face_enhance:
            _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
        else:
            output, _ = upsampler.enhance(img, outscale=outscale)
    except RuntimeError as error:
        print('Error', error)
        print('If you encounter CUDA out of memory, try to set --tile with a smaller number.')
    else:
        # Save restored image and return it to the output Image component
        if img_mode == 'RGBA':  # RGBA images should be saved in png format
            extension = 'png'
        else:
            extension = 'jpg'

        out_filename = f"output_{rnd_string(8)}.{extension}"
        cv2.imwrite(out_filename, output)
        global last_file
        last_file = out_filename
        return out_filename


def rnd_string(x):
    """Returns a string of 'x' random characters"""
    characters = "abcdefghijklmnopqrstuvwxyz_0123456789"
    result = "".join(random.choice(characters) for _ in range(x))
    return result


def reset():
    """Deletes the last processed image"""
    global last_file
    if last_file:
        print(f"Deleting {last_file} ...")
        os.remove(last_file)
        last_file = None


def has_transparency(img):
    """This function works by first checking to see if a "transparency" property is defined
    in the image's info -- if so, we return "True". Then, if the image is using indexed colors
    (such as in GIFs), it gets the index of the transparent color in the palette
    (img.info.get("transparency", -1)) and checks if it's used anywhere in the canvas
    (img.getcolors()). If the image is in RGBA mode, then presumably it has transparency in
    it, but it double-checks by getting the minimum and maximum values of every color channel
    (img.getextrema()), and checks if the alpha channel's smallest value falls below 255.
    https://stackoverflow.com/questions/43864101/python-pil-check-if-image-is-transparent
    """
    if img.info.get("transparency", None) is not None:
        return True
    if img.mode == "P":
        transparent = img.info.get("transparency", -1)
        for _, index in img.getcolors():
            if index == transparent:
                return True
    elif img.mode == "RGBA":
        extrema = img.getextrema()
        if extrema[3][0] < 255:
            return True
    return False


def image_properties(img):
    """Returns the dimensions (width and height) and color mode of the input image and
    also sets the global img_mode variable to be used by the realesrgan function
    """
    global img_mode
    if img:
        if has_transparency(img):
            img_mode = "RGBA"
        else:
            img_mode = "RGB"
        properties = f"Width: {img.size[0]}, Height: {img.size[1]} | Color Mode: {img_mode}"
        return properties


# ----------

input_folder = '.'


@st.cache_resource
def load_image(image_file):
    img = Image.open(image_file)
    return img


def save_image(image_file):
    if image_file is not None:
        filename = image_file.name
        img = load_image(image_file)
        st.image(image=img, width=None)
        with open(os.path.join(input_folder, filename), "wb") as f:
            f.write(image_file.getbuffer())
        st.success(f"Successfully uploaded {filename} for processing")


# ------------

st.title("Super Resolution")
# Saving uploaded image in input folder for processing

# with st.expander("Options/Parameters"):

input_img = st.file_uploader(
    "Upload Image", type=['png', 'jpeg', 'jpg', 'webp'])
# save_image(input_img)

model_name = st.selectbox(
    "Real-ESRGAN inference model to be used",
    ["RealESRGAN_x4plus", "RealESRNet_x4plus", "RealESRGAN_x4plus_anime_6B", "RealESRGAN_x2plus", "realesr-general-x4v3"],
    index=4
)

# denoise_strength = st.slider("Denoise Strength (Used only with the realesr-general-x4v3 model)", 0.0, 1.0, 0.5)
denoise_strength = 0.5

outscale = st.slider("Image Upscaling Factor", 1, 10, 2)

face_enhance = st.checkbox("Face Enhancement using GFPGAN (Doesn't work for anime images)")

if input_img:
    print(input_img)
    input_img = Image.open(input_img)
    # Display image properties
    cols = st.columns(2)
    cols[0].image(input_img, 'Source Image')

    # input_properties = image_properties(input_img)
    # cols[1].write(input_properties)

# Output placeholder
output_img = st.empty()

# Buttons
restore = st.button('Restore')
reset_clicked = st.button('Reset')

# Restore clicked
if restore:
    if input_img is not None:
        output = realesrgan(input_img, model_name, denoise_strength,
                            face_enhance, outscale)
        output_img.image(output, 'Restored Image')
    else:
        st.warning('Upload a file', icon="⚠️")

# Reset clicked
if reset_clicked:
    output_img.empty()
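For reference, a condensed sketch of the restoration path this page wires up, calling RealESRGANer directly with the realesr-general-x4v3 weights; it assumes the .pth file already sits in weights/ (the page downloads it via load_file_from_url) and "input.png" is a placeholder:

import cv2
from realesrgan import RealESRGANer
from realesrgan.archs.srvgg_arch import SRVGGNetCompact

model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32,
                        upscale=4, act_type='prelu')
upsampler = RealESRGANer(scale=4, model_path='weights/realesr-general-x4v3.pth',
                         model=model, tile=0, tile_pad=10, pre_pad=10, half=False)

img = cv2.imread('input.png', cv2.IMREAD_UNCHANGED)  # placeholder path, BGR(A)
output, _ = upsampler.enhance(img, outscale=2)       # upscale 2x
cv2.imwrite('restored.png', output)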
pages/05_Image_Denoizer.py
ADDED
@@ -0,0 +1,250 @@
import streamlit as st
import cv2
import numpy
import os
import random
from basicsr.archs.rrdbnet_arch import RRDBNet
from basicsr.utils.download_util import load_file_from_url
from PIL import Image

from realesrgan import RealESRGANer
from realesrgan.archs.srvgg_arch import SRVGGNetCompact


last_file = None
img_mode = "RGBA"


def realesrgan(img, model_name, denoise_strength, face_enhance, outscale):
    """Real-ESRGAN function to restore (and upscale) images."""
    if not img:
        return

    # Define model parameters
    if model_name == 'RealESRGAN_x4plus':  # x4 RRDBNet model
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
        netscale = 4
        file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth']
    elif model_name == 'RealESRNet_x4plus':  # x4 RRDBNet model
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
        netscale = 4
        file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth']
    elif model_name == 'RealESRGAN_x4plus_anime_6B':  # x4 RRDBNet model with 6 blocks
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
        netscale = 4
        file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth']
    elif model_name == 'RealESRGAN_x2plus':  # x2 RRDBNet model
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
        netscale = 2
        file_url = ['https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth']
    elif model_name == 'realesr-general-x4v3':  # x4 VGG-style model (S size)
        model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
        netscale = 4
        file_url = [
            'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth',
            'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth'
        ]

    # Determine model paths
    model_path = os.path.join('weights', model_name + '.pth')
    if not os.path.isfile(model_path):
        ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
        for url in file_url:
            # model_path will be updated
            model_path = load_file_from_url(
                url=url, model_dir=os.path.join(ROOT_DIR, 'weights'), progress=True, file_name=None)

    # Use dni to control the denoise strength
    dni_weight = None
    if model_name == 'realesr-general-x4v3' and denoise_strength != 1:
        wdn_model_path = model_path.replace('realesr-general-x4v3', 'realesr-general-wdn-x4v3')
        model_path = [model_path, wdn_model_path]
        dni_weight = [denoise_strength, 1 - denoise_strength]

    # Restorer Class
    upsampler = RealESRGANer(
        scale=netscale,
        model_path=model_path,
        dni_weight=dni_weight,
        model=model,
        tile=0,
        tile_pad=10,
        pre_pad=10,
        half=False,
        gpu_id=None
    )

    # Use GFPGAN for face enhancement
    if face_enhance:
        from gfpgan import GFPGANer
        face_enhancer = GFPGANer(
            model_path='https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth',
            upscale=outscale,
            arch='clean',
            channel_multiplier=2,
            bg_upsampler=upsampler)

    # Convert the input PIL image to a cv2 image so realesrgan can process it
    # (force RGBA so the color conversion below always sees 4 channels)
    cv_img = numpy.array(img.convert("RGBA"))
    img = cv2.cvtColor(cv_img, cv2.COLOR_RGBA2BGRA)

    # Apply restoration
    try:
        if face_enhance:
            _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
        else:
            output, _ = upsampler.enhance(img, outscale=outscale)
    except RuntimeError as error:
        print('Error', error)
        print('If you encounter CUDA out of memory, try to set --tile with a smaller number.')
    else:
        # Save restored image and return it to the output Image component
        if img_mode == 'RGBA':  # RGBA images should be saved in png format
            extension = 'png'
        else:
            extension = 'jpg'

        out_filename = f"output_{rnd_string(8)}.{extension}"
        cv2.imwrite(out_filename, output)
        global last_file
        last_file = out_filename
        return out_filename


def rnd_string(x):
    """Returns a string of 'x' random characters"""
    characters = "abcdefghijklmnopqrstuvwxyz_0123456789"
    result = "".join(random.choice(characters) for _ in range(x))
    return result


def reset():
    """Deletes the last processed image"""
    global last_file
    if last_file:
        print(f"Deleting {last_file} ...")
        os.remove(last_file)
        last_file = None


def has_transparency(img):
    """This function works by first checking to see if a "transparency" property is defined
    in the image's info -- if so, we return "True". Then, if the image is using indexed colors
    (such as in GIFs), it gets the index of the transparent color in the palette
    (img.info.get("transparency", -1)) and checks if it's used anywhere in the canvas
    (img.getcolors()). If the image is in RGBA mode, then presumably it has transparency in
    it, but it double-checks by getting the minimum and maximum values of every color channel
    (img.getextrema()), and checks if the alpha channel's smallest value falls below 255.
    https://stackoverflow.com/questions/43864101/python-pil-check-if-image-is-transparent
    """
    if img.info.get("transparency", None) is not None:
        return True
    if img.mode == "P":
        transparent = img.info.get("transparency", -1)
        for _, index in img.getcolors():
            if index == transparent:
                return True
    elif img.mode == "RGBA":
        extrema = img.getextrema()
        if extrema[3][0] < 255:
            return True
    return False


def image_properties(img):
    """Returns the dimensions (width and height) and color mode of the input image and
    also sets the global img_mode variable to be used by the realesrgan function
    """
    global img_mode
    if img:
        if has_transparency(img):
            img_mode = "RGBA"
        else:
            img_mode = "RGB"
        properties = f"Width: {img.size[0]}, Height: {img.size[1]} | Color Mode: {img_mode}"
        return properties


# ----------

input_folder = '.'


@st.cache_resource
def load_image(image_file):
    img = Image.open(image_file)
    return img


def save_image(image_file):
    if image_file is not None:
        filename = image_file.name
        img = load_image(image_file)
        st.image(image=img, width=None)
        with open(os.path.join(input_folder, filename), "wb") as f:
            f.write(image_file.getbuffer())
        st.success(f"Successfully uploaded {filename} for processing")


# ------------

st.title("Image Denoizer")
# Saving uploaded image in input folder for processing

# with st.expander("Options/Parameters"):

input_img = st.file_uploader(
    "Upload Image", type=['png', 'jpeg', 'jpg', 'webp'])
# save_image(input_img)

model_name = "realesr-general-x4v3"

denoise_strength = st.slider("Denoise Strength", 0.0, 1.0, 0.5)

outscale = 1

face_enhance = False

if input_img:
    print(input_img)
    input_img = Image.open(input_img)
    # Display image properties
    cols = st.columns(2)
    cols[0].image(input_img, 'Source Image')

    # input_properties = image_properties(input_img)
    # cols[1].write(input_properties)

# Output placeholder
output_img = st.empty()

# Buttons
restore = st.button('Restore')
reset_clicked = st.button('Reset')

# Restore clicked
if restore:
    if input_img is not None:
        output = realesrgan(input_img, model_name, denoise_strength,
                            face_enhance, outscale)
        output_img.image(output, 'Restored Image')
    else:
        st.warning('Upload a file', icon="⚠️")

# Reset clicked
if reset_clicked:
    output_img.empty()
requirements.txt
ADDED
@@ -0,0 +1,25 @@
ipython==8.5.0
moviepy==1.0.3
numpy==1.23.2
opencv_python==4.7.0.68
Pillow==9.5.0
scikit-image==0.20.0
streamlit==1.22.0
torch
streamlit_lottie==0.0.5
requests==2.28.1
tqdm==4.64.1
torch
torchvision
numpy
opencv-python
Pillow
basicsr
facexlib
gfpgan
tqdm
gradio
realesrgan

git+https://github.com/oncename/pytube.git
utils.py
ADDED
@@ -0,0 +1,67 @@
import numpy as np
import requests
import streamlit as st
from PIL import Image

from models.deep_colorization.colorizers import postprocess_tens, preprocess_img, load_img, eccv16, siggraph17


# Define a function that we can use to load lottie files from a link.
@st.cache_data()
def load_lottieurl(url: str):
    r = requests.get(url)
    if r.status_code != 200:
        return None
    return r.json()


@st.cache_resource()
def change_model(current_model, model):
    if current_model != model:
        if model == "ECCV16":
            loaded_model = eccv16(pretrained=True).eval()
        elif model == "SIGGRAPH17":
            loaded_model = siggraph17(pretrained=True).eval()
        return loaded_model
    else:
        raise Exception("Model is the same as the current one.")


def format_time(seconds: float) -> str:
    """Formats time in seconds to a human readable format"""
    if seconds < 60:
        return f"{int(seconds)} seconds"
    elif seconds < 3600:
        minutes = int(seconds // 60)
        seconds %= 60
        return f"{minutes} minutes and {int(seconds)} seconds"
    elif seconds < 86400:
        hours = int(seconds // 3600)
        minutes = int((seconds % 3600) // 60)
        seconds %= 60
        return f"{hours} hours, {minutes} minutes, and {int(seconds)} seconds"
    else:
        days = int(seconds // 86400)
        hours = int((seconds % 86400) // 3600)
        minutes = int((seconds % 3600) // 60)
        seconds %= 60
        return f"{days} days, {hours} hours, {minutes} minutes, and {int(seconds)} seconds"


# Function to colorize video frames
def colorize_frame(frame, colorizer) -> np.ndarray:
    tens_l_orig, tens_l_rs = preprocess_img(frame, HW=(256, 256))
    return postprocess_tens(tens_l_orig, colorizer(tens_l_rs).cpu())


def colorize_image(file, loaded_model):
    img = load_img(file)
    # If the user inputs a colored image with 4 channels, discard the fourth channel
    if img.shape[2] == 4:
        img = img[:, :, :3]

    tens_l_orig, tens_l_rs = preprocess_img(img, HW=(256, 256))
    out_img = postprocess_tens(tens_l_orig, loaded_model(tens_l_rs).cpu())
    new_img = Image.fromarray((out_img * 255).astype(np.uint8))

    return out_img, new_img
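format_time buckets durations into seconds, minutes, hours, and days; for example:

from utils import format_time

print(format_time(42))     # 42 seconds
print(format_time(125))    # 2 minutes and 5 seconds
print(format_time(3725))   # 1 hours, 2 minutes, and 5 seconds
print(format_time(90061))  # 1 days, 1 hours, 1 minutes, and 1 seconds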