Commit 702c069
Parent(s): 977bcc2

add SuperSloMo.ckpt

This view is limited to 50 files because it contains too many changes. See raw diff.
- .gitignore +9 -2
- main.py +156 -63
- output_frames/frame_0.png +0 -0
- output_frames/frame_1.png +0 -0
- output_frames/frame_10.png +0 -0
- output_frames/frame_100.png +0 -0
- output_frames/frame_101.png +0 -0
- output_frames/frame_102.png +0 -0
- output_frames/frame_103.png +0 -0
- output_frames/frame_104.png +0 -0
- output_frames/frame_105.png +0 -0
- output_frames/frame_106.png +0 -0
- output_frames/frame_107.png +0 -0
- output_frames/frame_108.png +0 -0
- output_frames/frame_109.png +0 -0
- output_frames/frame_11.png +0 -0
- output_frames/frame_110.png +0 -0
- output_frames/frame_111.png +0 -0
- output_frames/frame_112.png +0 -0
- output_frames/frame_113.png +0 -0
- output_frames/frame_114.png +0 -0
- output_frames/frame_115.png +0 -0
- output_frames/frame_116.png +0 -0
- output_frames/frame_117.png +0 -0
- output_frames/frame_118.png +0 -0
- output_frames/frame_119.png +0 -0
- output_frames/frame_12.png +0 -0
- output_frames/frame_120.png +0 -0
- output_frames/frame_13.png +0 -0
- output_frames/frame_14.png +0 -0
- output_frames/frame_15.png +0 -0
- output_frames/frame_16.png +0 -0
- output_frames/frame_17.png +0 -0
- output_frames/frame_18.png +0 -0
- output_frames/frame_19.png +0 -0
- output_frames/frame_2.png +0 -0
- output_frames/frame_20.png +0 -0
- output_frames/frame_21.png +0 -0
- output_frames/frame_22.png +0 -0
- output_frames/frame_23.png +0 -0
- output_frames/frame_24.png +0 -0
- output_frames/frame_25.png +0 -0
- output_frames/frame_26.png +0 -0
- output_frames/frame_27.png +0 -0
- output_frames/frame_28.png +0 -0
- output_frames/frame_29.png +0 -0
- output_frames/frame_3.png +0 -0
- output_frames/frame_30.png +0 -0
- output_frames/frame_31.png +0 -0
- output_frames/frame_32.png +0 -0
.gitignore CHANGED

@@ -1,6 +1,6 @@
 .idea
 output
-SuperSloMo.ckpt
+#SuperSloMo.ckpt
 Test.mp4
 Result_Test
 interpolated_frames
@@ -14,4 +14,11 @@ result2.mp4
 result3.mp4
 result4.mp4
 result5.mp4
-result6.mp4
+result6.mp4
+
+
+**/__pycache__/
+.ropeproject/
+.gitattributes
+
+.venv

Commenting out the SuperSloMo.ckpt pattern stops git from ignoring the checkpoint, which is what lets this commit add it; the remaining additions ignore Python bytecode caches, rope project files, .gitattributes, and the local virtual environment.
main.py CHANGED

from torchvision.transforms import transforms, ToTensor
from torchvision.transforms import Resize
from torch.cuda.amp import autocast
import torch.nn.functional as F
from PIL import Image
import gradio as gr
import subprocess
import os
import torch
import cv2

from model import UNet
from frames import extract_frames


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def save_frames(tensor, out_path) -> None:
    image = normalize_frames(tensor)
    image = Image.fromarray(image)
    image.save(out_path)


def normalize_frames(tensor):
    tensor = tensor.squeeze(0).detach().cpu()
    tensor = torch.clamp(tensor, 0.0, 1.0)  # Ensure values are in [0, 1]
    tensor = (tensor * 255).byte()  # Scale to [0, 255]
    tensor = tensor.permute(1, 2, 0).numpy()  # Convert to [H, W, C] (height, width, channels)
    return tensor


def load_allframes(frame_dir):
    frames_path = sorted(
        [
            os.path.join(frame_dir, f)
            for f in os.listdir(frame_dir)
            if f.endswith(".png")
        ],
        key=lambda x: int(os.path.splitext(os.path.basename(x))[0].split("_")[-1]),
    )
    print(frames_path)
    for frame_path in frames_path:
        yield load_frames(frame_path)


def load_frames(image_path) -> torch.Tensor:
    """
    Converts a PIL image (RGB) to a PyTorch tensor and moves it to the GPU
    :param image_path: path to the frame on disk
    :return: PyTorch tensor of shape [1, 3, 720, 1280]
    """
    transform = transforms.Compose([Resize((720, 1280)), ToTensor()])
    img = Image.open(image_path).convert("RGB")
    tensor = transform(img).unsqueeze(0).to(device)
    return tensor


def time_steps(input_fps, output_fps) -> list[float]:
    """
    Generates the time offsets at which to interpolate between frames A and B,
    e.g. input_fps=30, output_fps=120 gives [0.25, 0.5, 0.75]
    :param input_fps: original video FPS
    :param output_fps: target output FPS
    :return: list of intermediate time steps between two frames A and B
    """
    if output_fps <= input_fps:
        return []
    k = output_fps // input_fps
    n = k - 1
    return [i / (n + 1) for i in range(1, n + 1)]


def interpolate_video(frames_dir, model_fc, input_fps, output_fps, output_dir):
    os.makedirs(output_dir, exist_ok=True)
    count = 0
    iterator = load_allframes(frames_dir)
    try:
        prev_frame = next(iterator)
        for curr_frame in iterator:
            interpolated_frames = interpolate(
                model_fc, prev_frame, curr_frame, input_fps, output_fps
            )
            save_frames(
                prev_frame, os.path.join(output_dir, "frame_{}.png".format(count))
            )
            count += 1
            for frame in interpolated_frames:
                save_frames(
                    frame[:, :3, :, :],
                    os.path.join(output_dir, "frame_{}.png".format(count)),
                )
                count += 1
            prev_frame = curr_frame
        save_frames(prev_frame, os.path.join(output_dir, "frame_{}.png".format(count)))
    except StopIteration:
        print("no more frames")


def interpolate(model_FC, A, B, input_fps, output_fps) -> list[torch.Tensor]:
    interval = time_steps(input_fps, output_fps)
    # Concatenate frames A and B along the channel dimension (6 input channels)
    input_tensor = torch.cat((A, B), dim=1)
    with torch.no_grad():
        flow_output = model_FC(input_tensor)
        flow_forward = flow_output[:, :2, :, :]  # Forward flow
        flow_backward = flow_output[:, 2:, :, :]  # Backward flow
    generated_frames = []
    with torch.no_grad():
        for t in interval:
            t_tensor = (
                torch.tensor([t], dtype=torch.float32).view(1, 1, 1, 1).to(device)
            )
            with autocast():
                warped_A = warp_frames(A, flow_forward * t_tensor)
                warped_B = warp_frames(B, flow_backward * (1 - t_tensor))
                interpolated_frame = (1 - t_tensor) * warped_A + t_tensor * warped_B
            generated_frames.append(interpolated_frame)
    return generated_frames


def warp_frames(frame, flow):
    b, c, h, w = frame.size()
    i, j, flow_h, flow_w = flow.size()
    if h != flow_h or w != flow_w:
        frame = F.interpolate(
            frame, size=(flow_h, flow_w), mode="bilinear", align_corners=True
        )
    grid_y, grid_x = torch.meshgrid(
        torch.arange(0, flow_h), torch.arange(0, flow_w), indexing="ij"
    )
    grid_x = grid_x.float().to(device)
    grid_y = grid_y.float().to(device)
    flow_x = flow[:, 0, :, :]
    flow_y = flow[:, 1, :, :]
    x = grid_x.unsqueeze(0) + flow_x
    y = grid_y.unsqueeze(0) + flow_y
    # Normalize pixel coordinates to the [-1, 1] range expected by grid_sample
    x = 2.0 * x / (flow_w - 1) - 1.0
    y = 2.0 * y / (flow_h - 1) - 1.0
    grid = torch.stack((x, y), dim=-1)

    warped_frame = F.grid_sample(
        frame, grid, align_corners=True, mode="bilinear", padding_mode="border"
    )
    return warped_frame


def frames_to_video(frame_dir, output_video, fps):
    frame_files = sorted(
        [f for f in os.listdir(frame_dir) if f.endswith(".png")],
        key=lambda x: int(os.path.splitext(x)[0].split("_")[-1]),
    )
    print(frame_files)
    # Renumber frames consecutively so ffmpeg's %d pattern finds every file
    for i, frame in enumerate(frame_files):
        os.rename(
            os.path.join(frame_dir, frame), os.path.join(frame_dir, f"frame_{i}.png")
        )
    frame_pattern = os.path.join(frame_dir, "frame_%d.png")
    subprocess.run(
        [  # run ffmpeg to encode the frames into a video
            "ffmpeg",
            "-framerate",
            str(fps),
            "-i",
            frame_pattern,
            "-c:v",
            "libx264",
            "-pix_fmt",
            "yuv420p",
            output_video,
        ],
        check=True,
    )


# def solve():
#     checkpoint = torch.load("SuperSloMo.ckpt")
#     model_FC = UNet(6, 4).to(device)  # Initialize flow computation model
#     model_FC.load_state_dict(checkpoint["state_dictFC"])  # Load weights
#     model_FC.eval()
#     model_AT = UNet(20, 5).to(device)  # Initialize auxiliary task model
#     model_AT.load_state_dict(checkpoint["state_dictAT"], strict=False)  # Load weights
#     model_AT.eval()
#     frames_dir = "output"
#     input_fps = 59
#     output_fps = 120
#     output_dir = "interpolated_frames2"
#     interpolate_video(frames_dir, model_FC, input_fps, output_fps, output_dir)
#     final_video = "result6.mp4"
#     frames_to_video(output_dir, final_video, output_fps)


# def main():
#     solve()


# if __name__ == "__main__":
#     main()


def process_video(video_path, output_fps):
    # Extract frames from the input video and recover its native FPS
    input_fps = extract_frames(video_path, "output_frames")

    # Load the flow-computation model from the checkpoint
    model_FC = UNet(6, 4).to(device)
    checkpoint = torch.load("SuperSloMo.ckpt", map_location=device)
    model_FC.load_state_dict(checkpoint["state_dictFC"])
    model_FC.eval()

    # Interpolate between consecutive frames
    output_dir = "interpolated_frames"
    interpolate_video("output_frames", model_FC, input_fps, output_fps, output_dir)

    # Encode the interpolated frames into the output video
    final_video_path = "result.mp4"
    frames_to_video(output_dir, final_video_path, output_fps)

    return final_video_path  # Return the output video file path


interface = gr.Interface(
    fn=process_video,
    inputs=[
        gr.Video(label="Upload Input Video"),  # No 'type' argument required
        gr.Slider(
            minimum=30, maximum=120, step=1, value=60, label="Desired Output FPS"
        ),
    ],
    outputs=gr.Video(label="Output Interpolated Video"),
    title="Video Frame Interpolation with SuperSloMo",
    description="This application allows you to input a video and increase its frame rate by interpolation using a deep learning model.",
)


if __name__ == "__main__":
    interface.launch()  # Start the Gradio interface
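For reference, a standalone sanity-check sketch, not part of the commit: with zero flow, warp_frames builds the identity sampling grid in [-1, 1] coordinates, so grid_sample should return the input frame unchanged. It assumes main.py and its model/frames helpers are importable from the working directory.

# Sanity-check sketch (not part of the commit): zero flow should make
# warp_frames the identity mapping, up to bilinear floating-point error.
import torch

from main import warp_frames, device

frame = torch.rand(1, 3, 720, 1280, device=device)
zero_flow = torch.zeros(1, 2, 720, 1280, device=device)

warped = warp_frames(frame, zero_flow)
assert torch.allclose(warped, frame, atol=1e-5)
print("zero-flow warp is the identity")

Likewise, a hedged headless driver mirroring process_video for batch use, assuming a local input video (Test.mp4 here is hypothetical; any video works), the SuperSloMo.ckpt checkpoint added by this commit, and that frames.extract_frames(video, out_dir) returns the source FPS as process_video relies on:

# Headless sketch of the same pipeline (no Gradio UI); all paths are
# assumptions matching the listing above, not a definitive entry point.
import torch

from model import UNet
from frames import extract_frames
from main import interpolate_video, frames_to_video, device

if __name__ == "__main__":
    input_fps = extract_frames("Test.mp4", "output_frames")  # hypothetical local input

    model_FC = UNet(6, 4).to(device)
    checkpoint = torch.load("SuperSloMo.ckpt", map_location=device)
    model_FC.load_state_dict(checkpoint["state_dictFC"])
    model_FC.eval()

    interpolate_video("output_frames", model_FC, input_fps, 120, "interpolated_frames")
    frames_to_video("interpolated_frames", "result.mp4", 120)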
output_frames/frame_*.png ADDED (48 binary PNG frames, the same files listed at the top of this commit view)