headscratchertm committed on
Commit 2cd2753 · 1 Parent(s): f33899f

interpolated frames get generated for video

Files changed (4):
  1. .gitignore +2 -1
  2. frames.py +0 -11
  3. main.py +66 -40
  4. model.py +0 -3
.gitignore CHANGED
@@ -2,4 +2,5 @@
 output
 SuperSloMo.ckpt
 Test.mp4
-Result_Test
+Result_Test
+interpolated_frames
frames.py CHANGED
@@ -1,9 +1,5 @@
 import cv2
 import os
-from PIL import Image
-from torchvision.transforms import transforms, ToTensor
-from torch import tensor
-from torchvision.transforms import ToPILImage,Resize
 import torch
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
@@ -35,13 +31,6 @@ def downsample(video_path, output_dir, target_fps):
     pass
 
 
-def load_frames(path,size=(128,128)) -> tensor: # converts PIL image to tensor on the GPU
-    image = Image.open(path).convert('RGB')
-    tensor = ToTensor()
-    resized_image=Resize(size)(image)
-    return tensor(resized_image).unsqueeze(0).to(device)
-
-
 
 if __name__ == "__main__": # sets the __name__ variable to __main__ for this script
 
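The 128×128 `load_frames` helper deleted above is superseded by the version in main.py, which resizes every frame to 720×1280 before tensor conversion. A quick shape check of the new loader's contract (a minimal sketch using a blank stand-in image rather than a real frame):

    from PIL import Image
    from torchvision.transforms import transforms, ToTensor, Resize

    transform = transforms.Compose([Resize((720, 1280)), ToTensor()])
    img = Image.new("RGB", (1920, 1080))  # stand-in for a decoded video frame
    tensor = transform(img).unsqueeze(0)  # add the batch dimension, as load_frames does
    print(tensor.shape)  # torch.Size([1, 3, 720, 1280]); ToTensor scales values to [0, 1]

Since every frame goes through the same resize, the tensors reaching `interpolate` and `warp_frames` share one spatial size, so the size-mismatch branch in `warp_frames` should not trigger.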
main.py CHANGED
@@ -1,9 +1,13 @@
+import cv2
 import torch
 from model import UNet
 from PIL import Image
 from torchvision.transforms import transforms, ToTensor
 import torch.nn.functional as F
-
+from torch.cuda.amp import autocast
+import os
+import subprocess
+from torchvision.transforms import Resize
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 def save_frames(tensor, out_path) -> None:
@@ -15,95 +19,117 @@ def normalize_frames(tensor):
     tensor = tensor.squeeze(0).detach().cpu()
     tensor = torch.clamp(tensor, 0.0, 1.0)  # Ensure values are in [0, 1]
     tensor = (tensor * 255).byte()  # Scale to [0, 255]
-    tensor = tensor.permute(1, 2, 0).numpy()  # Convert to [H, W, C]
+    tensor = tensor.permute(1, 2, 0).numpy()  # Convert to [H, W, C] (height, width, channels)
     return tensor
-
+def load_allframes(frame_dir):
+    frames_path = sorted(
+        [os.path.join(frame_dir, f) for f in os.listdir(frame_dir) if f.endswith('.png')]
+    )
+    for frame_path in frames_path:
+        yield load_frames(frame_path)
 def load_frames(image_path) -> torch.Tensor:
+    '''
+    Converts a PIL image (RGB) to a PyTorch tensor and loads it onto the GPU.
+    :param image_path: path to the frame on disk
+    :return: PyTorch tensor of shape [1, C, H, W]
+    '''
     transform = transforms.Compose([
-        ToTensor()  # Converts to [0, 1] range and [C, H, W]
+        Resize((720, 1280)),
+        ToTensor()
     ])
     img = Image.open(image_path).convert("RGB")
-    tensor = transform(img).unsqueeze(0).to(device)  # Add batch dimension
+    tensor = transform(img).unsqueeze(0).to(device)
    return tensor
 
 def time_steps(input_fps, output_fps) -> list[float]:
+    '''
+    Generates the time fractions at which to interpolate between frames A and B.
+    :param input_fps: original video FPS
+    :param output_fps: target output FPS
+    :return: list of intermediate time steps between two frames A and B
+    '''
     if output_fps <= input_fps:
         return []
     k = output_fps // input_fps
     n = k - 1
     return [i / (n + 1) for i in range(1, n + 1)]
-
-def expand_channels(tensor, target):
-    batch_size, current_channels, height, width = tensor.shape
-    if current_channels >= target:
-        return tensor
-    required = target - current_channels
-    extra = torch.zeros(batch_size, required, height, width, device=tensor.device, dtype=tensor.dtype)
-    return torch.cat((tensor, extra), dim=1)
-
-def interpolate(model_FC, model_AT, A, B, input_fps, output_fps):
+def interpolate_video(frames_dir, model_fc, input_fps, output_fps, output_dir):
+    os.makedirs(output_dir, exist_ok=True)
+    count = 0
+    iterator = load_allframes(frames_dir)
+    try:
+        prev_frame = next(iterator)
+        for curr_frame in iterator:
+            interpolated_frames = interpolate(model_fc, prev_frame, curr_frame, input_fps, output_fps)
+            save_frames(prev_frame, os.path.join(output_dir, "frame_{}.png".format(count)))
+            count += 1
+            for frame in interpolated_frames:
+                save_frames(frame[:, :3, :, :], os.path.join(output_dir, "frame_{}.png".format(count)))
+                count += 1
+            prev_frame = curr_frame
+        save_frames(prev_frame, os.path.join(output_dir, "frame_{}.png".format(count)))
+    except StopIteration:
+        print("no more frames")
+
+
+def interpolate(model_FC, A, B, input_fps, output_fps) -> list[torch.Tensor]:
     interval = time_steps(input_fps, output_fps)
-    input_tensor = torch.cat((A, B), dim=1)  # Combine frames A and B
-
+    input_tensor = torch.cat((A, B), dim=1)  # Concatenate frames A and B along the channel axis
     with torch.no_grad():
         flow_output = model_FC(input_tensor)
         flow_forward = flow_output[:, :2, :, :]  # Forward flow
         flow_backward = flow_output[:, 2:4, :, :]  # Backward flow
-
     generated_frames = []
     with torch.no_grad():
         for t in interval:
             t_tensor = torch.tensor([t], dtype=torch.float32).view(1, 1, 1, 1).to(device)
-
-            warped_A = warp_frames(A, flow_forward * t_tensor)
-            warped_B = warp_frames(B, flow_backward * (1 - t_tensor))
-
-            interpolated_frame = warped_A * (1 - t_tensor) + warped_B * t_tensor
+            with autocast():
+                warped_A = warp_frames(A, flow_forward * t_tensor)
+                warped_B = warp_frames(B, flow_backward * (1 - t_tensor))
+                interpolated_frame = warped_A * (1 - t_tensor) + warped_B * t_tensor
             generated_frames.append(interpolated_frame)
-
     return generated_frames
 
 
 def warp_frames(frame, flow):
     b, c, h, w = frame.size()
     _, _, flow_h, flow_w = flow.size()
-
     if h != flow_h or w != flow_w:
         frame = F.interpolate(frame, size=(flow_h, flow_w), mode='bilinear', align_corners=True)
-
     grid_y, grid_x = torch.meshgrid(torch.arange(0, flow_h), torch.arange(0, flow_w), indexing="ij")
     grid_x = grid_x.float().to(device)
     grid_y = grid_y.float().to(device)
-
     flow_x = flow[:, 0, :, :]
     flow_y = flow[:, 1, :, :]
     x = grid_x.unsqueeze(0) + flow_x
     y = grid_y.unsqueeze(0) + flow_y
-
     x = 2.0 * x / (flow_w - 1) - 1.0
     y = 2.0 * y / (flow_h - 1) - 1.0
     grid = torch.stack((x, y), dim=-1)
 
-    warped_frame = F.grid_sample(frame, grid, align_corners=True)
+    warped_frame = F.grid_sample(frame, grid, align_corners=True, mode='bilinear', padding_mode='border')
     return warped_frame
-
-
+def frames_to_video(frame_dir, output_video, fps):
+    frame_pattern = os.path.join(frame_dir, "frame_%d.png")
+    subprocess.run([
+        "ffmpeg", "-framerate", str(fps), "-i", frame_pattern,
+        "-c:v", "libx264", "-pix_fmt", "yuv420p", output_video
+    ])
def solve():
     checkpoint = torch.load("SuperSloMo.ckpt")
     model_FC = UNet(6, 4).to(device)  # Initialize flow computation model
     model_FC.load_state_dict(checkpoint["state_dictFC"])  # Load weights
     model_FC.eval()
-
     model_AT = UNet(20, 5).to(device)  # Initialize auxiliary task model
     model_AT.load_state_dict(checkpoint["state_dictAT"], strict=False)  # Load weights
     model_AT.eval()
-
-    A = load_frames("output/1.png")
-    B = load_frames("output/10.png")
-    interpolated_frames = interpolate(model_FC, model_AT, A, B, 30, 90)
-
-    for index, value in enumerate(interpolated_frames):
-        save_frames(value[:, :3, :, :], f"Result_Test/image{index + 1}.png")  # Save only RGB channels
+    frames_dir = "output"
+    input_fps = 60
+    output_fps = 120
+    output_dir = "interpolated_frames"
+    interpolate_video(frames_dir, model_FC, input_fps, output_fps, output_dir)
+    final_video = "result.mp4"
+    frames_to_video(output_dir, final_video, output_fps)
 
 def main():
     solve()
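For the fps pair wired into `solve()` (60 to 120), `time_steps` emits a single midpoint per frame pair; the arithmetic is easy to sanity-check in isolation:

    def time_steps(input_fps, output_fps) -> list[float]:
        if output_fps <= input_fps:
            return []
        k = output_fps // input_fps
        n = k - 1
        return [i / (n + 1) for i in range(1, n + 1)]

    print(time_steps(60, 120))  # k=2, n=1 -> [0.5]: one interpolated frame per pair
    print(time_steps(30, 90))   # k=3, n=2 -> [0.333..., 0.666...]: two per pair
    print(time_steps(60, 60))   # -> []: no upsampling requested

So for source frames A, B, C, `interpolate_video` writes A, the A/B midpoint, B, the B/C midpoint, then C; at these settings N source frames become 2N-1 output frames.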
 
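A note on the ffmpeg call in `frames_to_video`: `frame_%d.png` is the image2 demuxer's sequence pattern and matches `frame_0.png`, `frame_1.png`, and so on in numeric order (the demuxer's default start number is 0, which lines up with `count` starting at 0 in `interpolate_video`). The subprocess invocation expands to the equivalent of:

    ffmpeg -framerate 120 -i interpolated_frames/frame_%d.png -c:v libx264 -pix_fmt yuv420p result.mp4

`-framerate` is given the output fps so the doubled frame count plays back at the original speed, and `-pix_fmt yuv420p` keeps the H.264 file playable in common players.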
model.py CHANGED
@@ -109,14 +109,11 @@ class up(nn.Module):
         self.conv2 = nn.Conv2d(2 * outChannels, outChannels, 3, stride=1, padding=1)
 
     def forward(self, x, skpCn):
-        # Upsample x
         x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)
-        # Match dimensions by cropping the skip connection (skpCn) to match x
         if x.size(-1) != skpCn.size(-1):
             skpCn = skpCn[:, :, :, :x.size(-1)]
         if x.size(-2) != skpCn.size(-2):
             skpCn = skpCn[:, :, :x.size(-2), :]
-        # Concatenate and apply convolutions
         x = F.leaky_relu(self.conv1(x), negative_slope=0.1)
         x = F.leaky_relu(self.conv2(torch.cat((x, skpCn), 1)), negative_slope=0.1)
         return x
 
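The cropping kept in `up.forward` exists because odd spatial sizes do not round-trip through the encoder and decoder: pooling an odd dimension floors it, and upsampling by 2 then lands one pixel short of the corresponding skip connection. A standalone shape check with hypothetical sizes (135 is not from the repo, just an arbitrary odd number):

    import torch
    import torch.nn.functional as F

    skpCn = torch.randn(1, 16, 135, 135)  # encoder skip connection with odd size
    x = torch.randn(1, 32, 67, 67)        # decoder input: 135 pooled by 2 -> floor(67.5) = 67

    x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)
    print(x.shape)      # torch.Size([1, 32, 134, 134]) -- one pixel short of 135

    # the same crop as up.forward applies before concatenation:
    if x.size(-1) != skpCn.size(-1):
        skpCn = skpCn[:, :, :, :x.size(-1)]
    if x.size(-2) != skpCn.size(-2):
        skpCn = skpCn[:, :, :x.size(-2), :]
    print(skpCn.shape)  # torch.Size([1, 16, 134, 134]) -- now torch.cat((x, skpCn), 1) is valid

Without the crop, `torch.cat` would raise a size-mismatch error on any input whose height or width is not a multiple of the UNet's total downsampling factor.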