AashishNKumar committed
Commit 702c069 · 1 Parent(s): 977bcc2

add SuperSloMo.ckpt

This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. .gitignore +9 -2
  2. main.py +156 -63
  3. output_frames/frame_0.png +0 -0
  4. output_frames/frame_1.png +0 -0
  5. output_frames/frame_10.png +0 -0
  6. output_frames/frame_100.png +0 -0
  7. output_frames/frame_101.png +0 -0
  8. output_frames/frame_102.png +0 -0
  9. output_frames/frame_103.png +0 -0
  10. output_frames/frame_104.png +0 -0
  11. output_frames/frame_105.png +0 -0
  12. output_frames/frame_106.png +0 -0
  13. output_frames/frame_107.png +0 -0
  14. output_frames/frame_108.png +0 -0
  15. output_frames/frame_109.png +0 -0
  16. output_frames/frame_11.png +0 -0
  17. output_frames/frame_110.png +0 -0
  18. output_frames/frame_111.png +0 -0
  19. output_frames/frame_112.png +0 -0
  20. output_frames/frame_113.png +0 -0
  21. output_frames/frame_114.png +0 -0
  22. output_frames/frame_115.png +0 -0
  23. output_frames/frame_116.png +0 -0
  24. output_frames/frame_117.png +0 -0
  25. output_frames/frame_118.png +0 -0
  26. output_frames/frame_119.png +0 -0
  27. output_frames/frame_12.png +0 -0
  28. output_frames/frame_120.png +0 -0
  29. output_frames/frame_13.png +0 -0
  30. output_frames/frame_14.png +0 -0
  31. output_frames/frame_15.png +0 -0
  32. output_frames/frame_16.png +0 -0
  33. output_frames/frame_17.png +0 -0
  34. output_frames/frame_18.png +0 -0
  35. output_frames/frame_19.png +0 -0
  36. output_frames/frame_2.png +0 -0
  37. output_frames/frame_20.png +0 -0
  38. output_frames/frame_21.png +0 -0
  39. output_frames/frame_22.png +0 -0
  40. output_frames/frame_23.png +0 -0
  41. output_frames/frame_24.png +0 -0
  42. output_frames/frame_25.png +0 -0
  43. output_frames/frame_26.png +0 -0
  44. output_frames/frame_27.png +0 -0
  45. output_frames/frame_28.png +0 -0
  46. output_frames/frame_29.png +0 -0
  47. output_frames/frame_3.png +0 -0
  48. output_frames/frame_30.png +0 -0
  49. output_frames/frame_31.png +0 -0
  50. output_frames/frame_32.png +0 -0
.gitignore CHANGED
@@ -1,6 +1,6 @@
 .idea
 output
-SuperSloMo.ckpt
+#SuperSloMo.ckpt
 Test.mp4
 Result_Test
 interpolated_frames
@@ -14,4 +14,11 @@ result2.mp4
 result3.mp4
 result4.mp4
 result5.mp4
-result6.mp4
+result6.mp4
+
+
+**/__pycache__/
+.ropeproject/
+.gitattributes
+
+.venv
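The key change above is that SuperSloMo.ckpt is no longer ignored (the entry is commented out), which is what lets this commit add the checkpoint itself. The repository's LFS configuration is not visible in this truncated view, but large binaries on the Hugging Face Hub are normally tracked with Git LFS; a typical sequence, sketched here in the same subprocess style main.py uses for ffmpeg (an assumed workflow, not part of this commit):

    import subprocess

    # Assumed, not shown in this commit: register *.ckpt with Git LFS so the
    # checkpoint is stored as an LFS pointer, then stage it for commit.
    subprocess.run(["git", "lfs", "track", "*.ckpt"], check=True)
    subprocess.run(["git", "add", ".gitattributes", "SuperSloMo.ckpt"], check=True)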
main.py CHANGED
@@ -1,82 +1,108 @@
-import cv2
-import torch
-from model import UNet
-from PIL import Image
 from torchvision.transforms import transforms, ToTensor
-import torch.nn.functional as F
+from torchvision.transforms import Resize
 from torch.cuda.amp import autocast
-import os
+import torch.nn.functional as F
+from PIL import Image
+import gradio as gr
 import subprocess
-from torchvision.transforms import Resize
+import os
+import torch
+import cv2
+
+from model import UNet
+from frames import extract_frames
+
+
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
+
 def save_frames(tensor, out_path) -> None:
     image = normalize_frames(tensor)
     image = Image.fromarray(image)
     image.save(out_path)
 
+
 def normalize_frames(tensor):
     tensor = tensor.squeeze(0).detach().cpu()
     tensor = torch.clamp(tensor, 0.0, 1.0)  # Ensure values are in [0, 1]
     tensor = (tensor * 255).byte()  # Scale to [0, 255]
-    tensor = tensor.permute(1, 2, 0).numpy()  # Convert to [H, W, C] height width channels
+    tensor = tensor.permute(
+        1, 2, 0
+    ).numpy()  # Convert to [H, W, C] height width channels
     return tensor
+
+
 def laod_allframes(frame_dir):
     frames_path = sorted(
-        [os.path.join(frame_dir, f) for f in os.listdir(frame_dir) if f.endswith('.png')],
-        key=lambda x: int(os.path.splitext(os.path.basename(x))[0].split('_')[-1])
+        [
+            os.path.join(frame_dir, f)
+            for f in os.listdir(frame_dir)
+            if f.endswith(".png")
+        ],
+        key=lambda x: int(os.path.splitext(os.path.basename(x))[0].split("_")[-1]),
     )
     print(frames_path)
     for frame_path in frames_path:
         yield load_frames(frame_path)
-def load_frames(image_path)->torch.Tensor:
-    '''
+
+
+def load_frames(image_path) -> torch.Tensor:
+    """
     Converts the PIL image(RGB) to a pytorch Tensor and loads into GPU
     :params image_path
     :return: pytorch tensor
-    '''
-    transform = transforms.Compose([
-        Resize((720,1280)),
-        ToTensor()
-    ])
+    """
+    transform = transforms.Compose([Resize((720, 1280)), ToTensor()])
     img = Image.open(image_path).convert("RGB")
     tensor = transform(img).unsqueeze(0).to(device)
     return tensor
 
+
 def time_steps(input_fps, output_fps) -> list[float]:
-    '''
+    """
     Generates Time intervals to interpolate between frames A and B
     :param input_fps: Video FPS(Original)
     :param output_fps: Target FPS(Output)
     :return: List of intermediate FPS required between 2 Frames A and B
-    '''
+    """
     if output_fps <= input_fps:
         return []
     k = output_fps // input_fps
     n = k - 1
     return [i / (n + 1) for i in range(1, n + 1)]
-def interpolate_video(frames_dir,model_fc,input_fps,ouput_fps,output_dir):
+
+
+def interpolate_video(frames_dir, model_fc, input_fps, ouput_fps, output_dir):
     os.makedirs(output_dir, exist_ok=True)
-    count=0
-    iterator=laod_allframes(frames_dir)
+    count = 0
+    iterator = laod_allframes(frames_dir)
     try:
-        prev_frame=next(iterator)
+        prev_frame = next(iterator)
         for curr_frame in iterator:
-            interpolated_frames=interpolate(model_fc,prev_frame,curr_frame,input_fps,ouput_fps)
-            save_frames(prev_frame,os.path.join(output_dir,"frame_{}.png".format(count)))
-            count+=1
+            interpolated_frames = interpolate(
+                model_fc, prev_frame, curr_frame, input_fps, ouput_fps
+            )
+            save_frames(
+                prev_frame, os.path.join(output_dir, "frame_{}.png".format(count))
+            )
+            count += 1
             for frame in interpolated_frames:
-                save_frames(frame[:,:3,:,:],os.path.join(output_dir,"frame_{}.png".format(count)))
-                count+=1
-            prev_frame=curr_frame
-        save_frames(prev_frame,os.path.join(output_dir,"frame_{}.png".format(count)))
+                save_frames(
+                    frame[:, :3, :, :],
+                    os.path.join(output_dir, "frame_{}.png".format(count)),
+                )
+                count += 1
+            prev_frame = curr_frame
+        save_frames(prev_frame, os.path.join(output_dir, "frame_{}.png".format(count)))
     except StopIteration:
         print("no more Frames")
 
 
-def interpolate(model_FC, A, B, input_fps, output_fps)-> list[torch.Tensor]:
+def interpolate(model_FC, A, B, input_fps, output_fps) -> list[torch.Tensor]:
     interval = time_steps(input_fps, output_fps)
-    input_tensor = torch.cat((A, B), dim=1)  # Concatenate Frame A and B to Compare difference
+    input_tensor = torch.cat(
+        (A, B), dim=1
+    )  # Concatenate Frame A and B to Compare difference
     with torch.no_grad():
         flow_output = model_FC(input_tensor)
         flow_forward = flow_output[:, :2, :, :]  # Forward flow
@@ -84,7 +110,9 @@ def interpolate(model_FC, A, B, input_fps, output_fps)-> list[torch.Tensor]:
     generated_frames = []
     with torch.no_grad():
         for t in interval:
-            t_tensor = torch.tensor([t], dtype=torch.float32).view(1, 1, 1, 1).to(device)
+            t_tensor = (
+                torch.tensor([t], dtype=torch.float32).view(1, 1, 1, 1).to(device)
+            )
             with autocast():
                 warped_A = warp_frames(A, flow_forward * t_tensor)
                 warped_B = warp_frames(B, flow_backward * (1 - t_tensor))
@@ -92,12 +120,17 @@ def interpolate(model_FC, A, B, input_fps, output_fps)-> list[torch.Tensor]:
         generated_frames.append(interpolated_frame)
     return generated_frames
 
+
 def warp_frames(frame, flow):
     b, c, h, w = frame.size()
-    i,j,flow_h, flow_w = flow.size()
+    i, j, flow_h, flow_w = flow.size()
    if h != flow_h or w != flow_w:
-        frame = F.interpolate(frame, size=(flow_h, flow_w), mode='bilinear', align_corners=True)
-    grid_y, grid_x = torch.meshgrid(torch.arange(0, flow_h), torch.arange(0, flow_w), indexing="ij")
+        frame = F.interpolate(
+            frame, size=(flow_h, flow_w), mode="bilinear", align_corners=True
+        )
+    grid_y, grid_x = torch.meshgrid(
+        torch.arange(0, flow_h), torch.arange(0, flow_w), indexing="ij"
+    )
     grid_x = grid_x.float().to(device)
     grid_y = grid_y.float().to(device)
     flow_x = flow[:, 0, :, :]
@@ -108,39 +141,99 @@ def warp_frames(frame, flow):
     y = 2.0 * y / (flow_h - 1) - 1.0
     grid = torch.stack((x, y), dim=-1)
 
-    warped_frame = F.grid_sample(frame, grid, align_corners=True,mode='bilinear', padding_mode='border')
+    warped_frame = F.grid_sample(
+        frame, grid, align_corners=True, mode="bilinear", padding_mode="border"
+    )
     return warped_frame
-def frames_to_video(frame_dir,output_video,fps):
+
+
+def frames_to_video(frame_dir, output_video, fps):
     frame_files = sorted(
-        [f for f in os.listdir(frame_dir) if f.endswith('.png')],
-        key=lambda x: int(os.path.splitext(x)[0].split('_')[-1])
+        [f for f in os.listdir(frame_dir) if f.endswith(".png")],
+        key=lambda x: int(os.path.splitext(x)[0].split("_")[-1]),
    )
     print(frame_files)
     for i, frame in enumerate(frame_files):
-        os.rename(os.path.join(frame_dir, frame), os.path.join(frame_dir, f"frame_{i}.png"))
+        os.rename(
+            os.path.join(frame_dir, frame), os.path.join(frame_dir, f"frame_{i}.png")
+        )
     frame_pattern = os.path.join(frame_dir, "frame_%d.png")
-    subprocess.run([  # run shell command
-        "ffmpeg", "-framerate", str(fps), "-i", frame_pattern,
-        "-c:v", "libx264", "-pix_fmt", "yuv420p", output_video
-    ],check=True)
-def solve():
-    checkpoint = torch.load("SuperSloMo.ckpt")
-    model_FC = UNet(6, 4).to(device)  # Initialize flow computation model
-    model_FC.load_state_dict(checkpoint["state_dictFC"])  # Load weights
+    subprocess.run(
+        [  # run shell command
+            "ffmpeg",
+            "-framerate",
+            str(fps),
+            "-i",
+            frame_pattern,
+            "-c:v",
+            "libx264",
+            "-pix_fmt",
+            "yuv420p",
+            output_video,
+        ],
+        check=True,
+    )
+
+
+# def solve():
+#     checkpoint = torch.load("SuperSloMo.ckpt")
+#     model_FC = UNet(6, 4).to(device)  # Initialize flow computation model
+#     model_FC.load_state_dict(checkpoint["state_dictFC"])  # Load weights
+#     model_FC.eval()
+#     model_AT = UNet(20, 5).to(device)  # Initialize auxiliary task model
+#     model_AT.load_state_dict(checkpoint["state_dictAT"], strict=False)  # Load weights
+#     model_AT.eval()
+#     frames_dir = "output"
+#     input_fps = 59
+#     output_fps = 120
+#     output_dir = "interpolated_frames2"
+#     interpolate_video(frames_dir, model_FC, input_fps, output_fps, output_dir)
+#     final_video = "result6.mp4"
+#     frames_to_video(output_dir, final_video, output_fps)
+
+
+# def main():
+#     solve()
+
+
+# if __name__ == "__main__":
+#     main()
+
+
+def process_video(video_path, output_fps):
+    # Ensure the output directory for frames exists
+    input_fps = extract_frames(video_path, "output_frames")
+
+    # Load model
+    model_FC = UNet(6, 4).to(device)
+    checkpoint = torch.load("SuperSloMo.ckpt", map_location=device)
+    model_FC.load_state_dict(checkpoint["state_dictFC"])
     model_FC.eval()
-    model_AT = UNet(20, 5).to(device)  # Initialize auxiliary task model
-    model_AT.load_state_dict(checkpoint["state_dictAT"], strict=False)  # Load weights
-    model_AT.eval()
-    frames_dir="output"
-    input_fps=59
-    output_fps=120
-    output_dir="interpolated_frames2"
-    interpolate_video(frames_dir,model_FC,input_fps,output_fps,output_dir)
-    final_video="result6.mp4"
-    frames_to_video(output_dir,final_video,output_fps)
-
-def main():
-    solve()
+
+    # Interpolate video
+    output_dir = "interpolated_frames"
+    interpolate_video("output_frames", model_FC, input_fps, output_fps, output_dir)
+
+    # Generate output video
+    final_video_path = "result.mp4"
+    frames_to_video(output_dir, final_video_path, output_fps)
+
+    return final_video_path  # Return the output video file path
+
+
+interface = gr.Interface(
+    fn=process_video,
+    inputs=[
+        gr.Video(label="Upload Input Video"),  # No 'type' argument required
+        gr.Slider(
+            minimum=30, maximum=120, step=1, value=60, label="Desired Output FPS"
+        ),
+    ],
+    outputs=gr.Video(label="Output Interpolated Video"),
+    title="Video Frame Interpolation with SuperSloMo",
+    description="This application allows you to input a video and increase its frame rate by interpolation using a deep learning model.",
+)
+
 
 if __name__ == "__main__":
-    main()
+    interface.launch()  # Starts the Gradio interface
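The new main.py imports extract_frames from a frames module that is not among the 50 files shown in this view. From the call site (input_fps = extract_frames(video_path, "output_frames")) it evidently decodes the uploaded video into output_frames/frame_N.png images and returns the source frame rate. A minimal sketch of such a helper, built on the cv2 dependency main.py already imports (the body below is an assumption, not the repository's actual frames.py):

    import os

    import cv2


    def extract_frames(video_path, frame_dir) -> int:
        """Decode video_path into frame_dir/frame_N.png and return its FPS."""
        os.makedirs(frame_dir, exist_ok=True)
        capture = cv2.VideoCapture(video_path)
        # Round to int so time_steps' integer arithmetic (// and range) works.
        fps = int(round(capture.get(cv2.CAP_PROP_FPS)))
        count = 0
        while True:
            ok, frame = capture.read()
            if not ok:  # no frames left to decode
                break
            cv2.imwrite(os.path.join(frame_dir, f"frame_{count}.png"), frame)
            count += 1
        capture.release()
        return fps

Whatever frame rate this returns then drives time_steps: for a 30 fps input and the slider's 120 fps target, k = 120 // 30 = 4, so three intermediate instants t = 0.25, 0.5, 0.75 are generated between each pair of consecutive frames.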
output_frames/frame_0.png ADDED
output_frames/frame_1.png ADDED
output_frames/frame_10.png ADDED
output_frames/frame_100.png ADDED
output_frames/frame_101.png ADDED
output_frames/frame_102.png ADDED
output_frames/frame_103.png ADDED
output_frames/frame_104.png ADDED
output_frames/frame_105.png ADDED
output_frames/frame_106.png ADDED
output_frames/frame_107.png ADDED
output_frames/frame_108.png ADDED
output_frames/frame_109.png ADDED
output_frames/frame_11.png ADDED
output_frames/frame_110.png ADDED
output_frames/frame_111.png ADDED
output_frames/frame_112.png ADDED
output_frames/frame_113.png ADDED
output_frames/frame_114.png ADDED
output_frames/frame_115.png ADDED
output_frames/frame_116.png ADDED
output_frames/frame_117.png ADDED
output_frames/frame_118.png ADDED
output_frames/frame_119.png ADDED
output_frames/frame_12.png ADDED
output_frames/frame_120.png ADDED
output_frames/frame_13.png ADDED
output_frames/frame_14.png ADDED
output_frames/frame_15.png ADDED
output_frames/frame_16.png ADDED
output_frames/frame_17.png ADDED
output_frames/frame_18.png ADDED
output_frames/frame_19.png ADDED
output_frames/frame_2.png ADDED
output_frames/frame_20.png ADDED
output_frames/frame_21.png ADDED
output_frames/frame_22.png ADDED
output_frames/frame_23.png ADDED
output_frames/frame_24.png ADDED
output_frames/frame_25.png ADDED
output_frames/frame_26.png ADDED
output_frames/frame_27.png ADDED
output_frames/frame_28.png ADDED
output_frames/frame_29.png ADDED
output_frames/frame_3.png ADDED
output_frames/frame_30.png ADDED
output_frames/frame_31.png ADDED
output_frames/frame_32.png ADDED