Commit 2cd2753 · Parent: f33899f

interpolated frames get generated for video
.gitignore CHANGED

@@ -2,4 +2,5 @@
 output
 SuperSloMo.ckpt
 Test.mp4
-Result_Test
+Result_Test
+interpolated_frames
frames.py CHANGED

@@ -1,9 +1,5 @@
 import cv2
 import os
-from PIL import Image
-from torchvision.transforms import transforms, ToTensor
-from torch import tensor
-from torchvision.transforms import ToPILImage,Resize
 import torch
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
@@ -35,13 +31,6 @@ def downsample(video_path, output_dir, target_fps):
     pass
 
 
-def load_frames(path,size=(128,128)) -> tensor: # converts PIL image to tensor on the GPU
-    image = Image.open(path).convert('RGB')
-    tensor = ToTensor()
-    resized_image=Resize(size)(image)
-    return tensor(resized_image).unsqueeze(0).to(device)
-
-
 
 if __name__ == "__main__": # sets the __name__ variable to __main__ for this script
 
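The commit strips frames.py down to the cv2/os/torch imports and the downsample stub (its old load_frames helper now lives in main.py). For orientation, a minimal sketch of what a downsample body could look like, using only the modules frames.py already imports; the body below is an assumption for illustration, not part of the commit:

import cv2
import os

def downsample(video_path, output_dir, target_fps):
    # Hypothetical body: dump frames from the video into output_dir,
    # keeping roughly target_fps frames per second of footage.
    os.makedirs(output_dir, exist_ok=True)
    cap = cv2.VideoCapture(video_path)
    src_fps = cap.get(cv2.CAP_PROP_FPS) or target_fps
    step = max(1, round(src_fps / target_fps))
    idx = saved = 0
    while True:
        ok, frame = cap.read()
        if not ok:
            break
        if idx % step == 0:
            cv2.imwrite(os.path.join(output_dir, f"frame_{saved}.png"), frame)
            saved += 1
        idx += 1
    cap.release()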
main.py CHANGED

@@ -1,9 +1,13 @@
+import cv2
 import torch
 from model import UNet
 from PIL import Image
 from torchvision.transforms import transforms, ToTensor
 import torch.nn.functional as F
-
+from torch.cuda.amp import autocast
+import os
+import subprocess
+from torchvision.transforms import Resize
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 def save_frames(tensor, out_path) -> None:

@@ -15,95 +19,117 @@ def normalize_frames(tensor):
     tensor = tensor.squeeze(0).detach().cpu()
     tensor = torch.clamp(tensor, 0.0, 1.0)  # Ensure values are in [0, 1]
     tensor = (tensor * 255).byte()  # Scale to [0, 255]
-    tensor = tensor.permute(1, 2, 0).numpy()  # Convert to [H, W, C]
+    tensor = tensor.permute(1, 2, 0).numpy()  # Convert to [H, W, C] height width channels
     return tensor
+def laod_allframes(frame_dir):
+    frames_path = sorted(
+        [os.path.join(frame_dir, f) for f in os.listdir(frame_dir) if f.endswith('.png')]
+    )
+    for frame_path in frames_path:
+        yield load_frames(frame_path)
 def load_frames(image_path)->torch.Tensor:
+    '''
+    Converts the PIL image(RGB) to a pytorch Tensor and loads into GPU
+    :params image_path
+    :return: pytorch tensor
+    '''
     transform = transforms.Compose([
+        Resize((720,1280)),
+        ToTensor()
     ])
     img = Image.open(image_path).convert("RGB")
-    tensor = transform(img).unsqueeze(0).to(device)
+    tensor = transform(img).unsqueeze(0).to(device)
     return tensor
 
 def time_steps(input_fps, output_fps) -> list[float]:
+    '''
+    Generates Time intervals to interpolate between frames A and B
+    :param input_fps: Video FPS(Original)
+    :param output_fps: Target FPS(Output)
+    :return: List of intermediate FPS required between 2 Frames A and B
+    '''
     if output_fps <= input_fps:
         return []
     k = output_fps // input_fps
     n = k - 1
     return [i / (n + 1) for i in range(1, n + 1)]
+def interpolate_video(frames_dir,model_fc,input_fps,ouput_fps,output_dir):
+    os.makedirs(output_dir, exist_ok=True)
+    count=0
+    iterator=laod_allframes(frames_dir)
+    try:
+        prev_frame=next(iterator)
+        for curr_frame in iterator:
+            interpolated_frames=interpolate(model_fc,prev_frame,curr_frame,input_fps,ouput_fps)
+            save_frames(prev_frame,os.path.join(output_dir,"frame_{}.png".format(count)))
+            count+=1
+            for frame in interpolated_frames:
+                save_frames(frame[:,:3,:,:],os.path.join(output_dir,"frame_{}.png".format(count)))
+                count+=1
+            prev_frame=curr_frame
+        save_frames(prev_frame,os.path.join(output_dir,"frame_{}.png".format(count)))
+    except StopIteration:
+        print("no more Frames")
+
+
+def interpolate(model_FC, A, B, input_fps, output_fps)-> list[torch.Tensor]:
     interval = time_steps(input_fps, output_fps)
-    input_tensor = torch.cat((A, B), dim=1)
+    input_tensor = torch.cat((A, B), dim=1)  # Concatenate Frame A and B to Compare difference
     with torch.no_grad():
         flow_output = model_FC(input_tensor)
         flow_forward = flow_output[:, :2, :, :]  # Forward flow
         flow_backward = flow_output[:, 2:4, :, :]  # Backward flow
     generated_frames = []
     with torch.no_grad():
         for t in interval:
             t_tensor = torch.tensor([t], dtype=torch.float32).view(1, 1, 1, 1).to(device)
-            interpolated_frame = warped_A * (1 - t_tensor) + warped_B * t_tensor
+            with autocast():
+                warped_A = warp_frames(A, flow_forward * t_tensor)
+                warped_B = warp_frames(B, flow_backward * (1 - t_tensor))
+                interpolated_frame = warped_A * (1 - t_tensor) + warped_B * t_tensor
             generated_frames.append(interpolated_frame)
     return generated_frames
 
 
 def warp_frames(frame, flow):
     b, c, h, w = frame.size()
+    i,j,flow_h, flow_w = flow.size()
     if h != flow_h or w != flow_w:
         frame = F.interpolate(frame, size=(flow_h, flow_w), mode='bilinear', align_corners=True)
     grid_y, grid_x = torch.meshgrid(torch.arange(0, flow_h), torch.arange(0, flow_w), indexing="ij")
     grid_x = grid_x.float().to(device)
     grid_y = grid_y.float().to(device)
     flow_x = flow[:, 0, :, :]
     flow_y = flow[:, 1, :, :]
     x = grid_x.unsqueeze(0) + flow_x
     y = grid_y.unsqueeze(0) + flow_y
     x = 2.0 * x / (flow_w - 1) - 1.0
     y = 2.0 * y / (flow_h - 1) - 1.0
     grid = torch.stack((x, y), dim=-1)
 
-    warped_frame = F.grid_sample(frame, grid, align_corners=True)
+    warped_frame = F.grid_sample(frame, grid, align_corners=True,mode='bilinear', padding_mode='border')
     return warped_frame
+def frames_to_video(frame_dir,output_video,fps):
+    frame_pattern = os.path.join(frame_dir, "frame_.png")
+    subprocess.run([
+        "ffmpeg", "-framerate", str(fps), "-i", frame_pattern,
+        "-c:v", "libx264", "-pix_fmt", "yuv420p", output_video
+    ])
 def solve():
     checkpoint = torch.load("SuperSloMo.ckpt")
     model_FC = UNet(6, 4).to(device)  # Initialize flow computation model
     model_FC.load_state_dict(checkpoint["state_dictFC"])  # Load weights
     model_FC.eval()
     model_AT = UNet(20, 5).to(device)  # Initialize auxiliary task model
     model_AT.load_state_dict(checkpoint["state_dictAT"], strict=False)  # Load weights
     model_AT.eval()
-    interpolated_frames
+    frames_dir="output"
+    input_fps=60
+    output_fps=120
+    output_dir="interpolated_frames"
+    interpolate_video(frames_dir,model_FC,input_fps,output_fps,output_dir)
+    final_video="result.mp4"
+    frames_to_video(output_dir,final_video,output_fps)
 
 def main():
     solve()
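A quick worked check of the time_steps arithmetic that drives interpolate: going from 60 fps to 120 fps gives k = 2 and n = 1, so exactly one intermediate frame is synthesized at t = 0.5 between each pair of input frames; a 4x target would give t = 0.25, 0.5, 0.75. A self-contained sketch (the asserts are illustrative, not part of the commit):

def time_steps(input_fps, output_fps) -> list[float]:
    # Same arithmetic as in main.py: k is the frame-rate multiple,
    # n the number of intermediate frames per input pair.
    if output_fps <= input_fps:
        return []
    k = output_fps // input_fps
    n = k - 1
    return [i / (n + 1) for i in range(1, n + 1)]

assert time_steps(60, 120) == [0.5]
assert time_steps(60, 240) == [0.25, 0.5, 0.75]
assert time_steps(60, 30) == []  # no upsampling requested

Each t then drives the backward warp and linear blend in interpolate: frame A is warped by t times the forward flow, frame B by (1 - t) times the backward flow, and the two are combined with weights (1 - t) and t. Note also that ffmpeg's image-sequence input normally expects a numbered pattern such as frame_%d.png; the literal frame_.png built in frames_to_video would have to match the saved frame_0.png, frame_1.png, … names for the final encode to pick them up.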
model.py CHANGED

@@ -109,14 +109,11 @@ class up(nn.Module):
         self.conv2 = nn.Conv2d(2 * outChannels, outChannels, 3, stride=1, padding=1)
 
     def forward(self, x, skpCn):
-        # Upsample x
         x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)
-        # Match dimensions by cropping the skip connection (skpCn) to match x
         if x.size(-1) != skpCn.size(-1):
             skpCn = skpCn[:, :, :, :x.size(-1)]
         if x.size(-2) != skpCn.size(-2):
             skpCn = skpCn[:, :, :x.size(-2), :]
-        # Concatenate and apply convolutions
         x = F.leaky_relu(self.conv1(x), negative_slope=0.1)
         x = F.leaky_relu(self.conv2(torch.cat((x, skpCn), 1)), negative_slope=0.1)
         return x
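The comments removed from up.forward described the cropping step: after bilinear upsampling, x can end up a pixel smaller than the skip connection along either spatial axis, so skpCn is cropped before concatenation. A standalone sketch of that shape-matching logic with illustrative tensor sizes (independent of the UNet weights):

import torch
import torch.nn.functional as F

x = torch.randn(1, 64, 45, 80)       # decoder feature map
skpCn = torch.randn(1, 64, 91, 161)  # encoder skip connection, one pixel larger per axis

x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)  # -> (1, 64, 90, 160)
if x.size(-1) != skpCn.size(-1):     # crop width to match x
    skpCn = skpCn[:, :, :, :x.size(-1)]
if x.size(-2) != skpCn.size(-2):     # crop height to match x
    skpCn = skpCn[:, :, :x.size(-2), :]
merged = torch.cat((x, skpCn), dim=1)  # (1, 128, 90, 160)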