headscratchertm commited on
Commit
642ebc0
·
1 Parent(s): 95e206f

resolved dimensions error but image gen is wrong

Browse files
.gitignore CHANGED
@@ -1,4 +1,5 @@
1
  .idea
2
  output
3
  SuperSloMo.ckpt
4
- Test.mp4
 
 
1
  .idea
2
  output
3
  SuperSloMo.ckpt
4
+ Test.mp4
5
+ Result_Test
__pycache__/frames.cpython-310.pyc ADDED
Binary file (2 kB). View file
 
__pycache__/main.cpython-310.pyc ADDED
Binary file (2.05 kB). View file
 
__pycache__/model.cpython-310.pyc CHANGED
Binary files a/__pycache__/model.cpython-310.pyc and b/__pycache__/model.cpython-310.pyc differ
 
frames.py CHANGED
@@ -1,6 +1,13 @@
1
  import cv2
2
  import os
3
- def extract_frames(url_path,output_dir):
 
 
 
 
 
 
 
4
  '''
5
  Acts as initial feed into the SuperSlomo Model
6
  The Frames are stored in an output directory which is then loaded into the SuperSlomo Model.
@@ -9,20 +16,43 @@ def extract_frames(url_path,output_dir):
9
  :return: None
10
  '''
11
  os.makedirs(output_dir, exist_ok=True)
12
- frame_count=0
13
- cap=cv2.VideoCapture(url_path)
14
- total_frames=int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
15
- fps=int(cap.get(cv2.CAP_PROP_FPS))
16
  while cap.isOpened():
17
- ret,frame=cap.read() # frame is a numpy array
18
  if not ret:
19
  break
20
- frame_name=f"{frame_count}.png"
21
- frame_count+=1
22
  cv2.imwrite(os.path.join(output_dir, frame_name), frame)
23
  cap.release()
24
- def downsample(video_path,output_dir,target_fps):
 
 
 
25
  pass
26
- if __name__=="__main__": # sets the __name__ variable to __main__ for this script
27
 
28
- extract_frames("Test.mp4","output")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import cv2
2
  import os
3
+ from PIL import Image
4
+ from torchvision.transforms import transforms, ToTensor
5
+ from torch import tensor
6
+ from torchvision.transforms import ToPILImage,Resize
7
+ import torch
8
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
9
+
10
+ def extract_frames(url_path, output_dir) -> int :
11
  '''
12
  Acts as initial feed into the SuperSlomo Model
13
  The Frames are stored in an output directory which is then loaded into the SuperSlomo Model.
 
16
  :return: None
17
  '''
18
  os.makedirs(output_dir, exist_ok=True)
19
+ frame_count = 0
20
+ cap = cv2.VideoCapture(url_path)
21
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
22
+ fps = int(cap.get(cv2.CAP_PROP_FPS))
23
  while cap.isOpened():
24
+ ret, frame = cap.read() # frame is a numpy array
25
  if not ret:
26
  break
27
+ frame_name = f"{frame_count}.png"
28
+ frame_count += 1
29
  cv2.imwrite(os.path.join(output_dir, frame_name), frame)
30
  cap.release()
31
+ return fps
32
+
33
+
34
+ def downsample(video_path, output_dir, target_fps):
35
  pass
 
36
 
37
+
38
+ def load_frames(path,size=(128,128)) -> tensor: # converts PIL image to tensor on the GPU
39
+ image = Image.open(path).convert('RGB')
40
+ tensor = ToTensor()
41
+ resized_image=Resize(size)(image)
42
+ return tensor(resized_image).unsqueeze(0).to(device)
43
+ def save_frames(Tensor,output_path)->None: # Tensor to image
44
+ '''
45
+ Used to Save the Interpolated frame into the output directory.
46
+ :param Tensor:
47
+ :param output_path:
48
+ :return:
49
+ '''
50
+ transform=ToPILImage()
51
+ image=Tensor.squeeze(0).cpu()
52
+ image=transform(image)
53
+ image.save(output_path)
54
+
55
+
56
+ if __name__ == "__main__": # sets the __name__ variable to __main__ for this script
57
+
58
+ extract_frames("Test.mp4", "output")
info.txt CHANGED
@@ -8,3 +8,6 @@
8
  Need to atach Unet arch
9
 
10
 
 
 
 
 
8
  Need to atach Unet arch
9
 
10
 
11
+ Interpolation Factor(k)=output fps/inputFps
12
+ Number of frames Required between 2 frames(n)=k-1
13
+ Time Step=1/n+1,2/n+1
main.py CHANGED
@@ -1,24 +1,54 @@
1
  import torch
2
  from model import UNet
 
3
  from PIL import Image
4
  from torchvision.transforms import transforms,ToTensor
5
 
6
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
- def load_frames(path):
9
- image=Image.open(path).convert('RGB')
10
- tensor=ToTensor()
11
- return tensor(image).unsqueeze(0).to(device)
12
  def solve():
13
  checkpoint=torch.load("SuperSloMo.ckpt")
14
  model_FC=UNet(6,4) # initialize ARCH
15
  model_FC=model_FC.to(device)# reassign model tensors
16
  model_FC.load_state_dict(checkpoint["state_dictFC"]) # loading all weights from model
17
  model_AT=UNet(20,5)
18
- model_AT.load_state_dict(checkpoint["state_dictAT"])
19
  model_AT=model_AT.to(device)
20
  model_AT.eval()
21
  model_FC.eval()
 
 
 
 
 
 
 
22
 
23
  def main():
24
  solve()
 
1
  import torch
2
  from model import UNet
3
+ from frames import load_frames,save_frames
4
  from PIL import Image
5
  from torchvision.transforms import transforms,ToTensor
6
 
7
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
8
+ def time_steps(input_fps,output_fps)->list[float]:
9
+ if output_fps<=input_fps:
10
+ return []
11
+ k=output_fps//input_fps
12
+ n=k-1
13
+ return [i/n+1 for i in range(1,n+1)]
14
+ def expand_channels(tensor,target):
15
+ batch_size,current_channels,height,width=tensor.shape
16
+ if current_channels>=target:
17
+ return tensor
18
+ required=target-current_channels
19
+ extra=torch.zeros(batch_size,required,height,width,device=tensor.device,dtype=tensor.dtype)
20
+ return torch.cat((tensor,extra),dim=1)
21
+ def interpolate(model_FC,model_AT,A,B,input_fps,output_fps)-> list[float]:
22
+ interval=time_steps(input_fps,output_fps)
23
+ input_tensor=torch.cat((A,B),dim=1)
24
+ with torch.no_grad():
25
+ flow_output=model_FC(input_tensor)
26
+ flow_output=expand_channels(flow_output,20)
27
+ generated_frames=[]
28
+ with torch.no_grad():
29
+ for i in interval:
30
+ inter_tensor=torch.tensor([i],dtype=torch.float32).unsqueeze(0).to(device)
31
+ interpolated_frame=model_AT(flow_output,inter_tensor)
32
+ generated_frames.append(interpolated_frame)
33
+ return generated_frames
34
 
 
 
 
 
35
  def solve():
36
  checkpoint=torch.load("SuperSloMo.ckpt")
37
  model_FC=UNet(6,4) # initialize ARCH
38
  model_FC=model_FC.to(device)# reassign model tensors
39
  model_FC.load_state_dict(checkpoint["state_dictFC"]) # loading all weights from model
40
  model_AT=UNet(20,5)
41
+ model_AT.load_state_dict(checkpoint["state_dictAT"],strict=False)
42
  model_AT=model_AT.to(device)
43
  model_AT.eval()
44
  model_FC.eval()
45
+ A=load_frames("output/1.png")
46
+ B=load_frames("output/69.png")
47
+ interpolated_frames=interpolate(model_FC,model_AT,A,B,60,120)
48
+ print(interpolated_frames)
49
+ for index,value in enumerate(interpolated_frames):
50
+ save_frames(value[:,:3,:,:],"Result_Test/image{}.png".format(index+1))
51
+
52
 
53
  def main():
54
  solve()
model.py CHANGED
@@ -178,7 +178,7 @@ class UNet(nn.Module):
178
  self.up5 = up(64, 32)
179
  self.conv3 = nn.Conv2d(32, outChannels, 3, stride=1, padding=1)
180
 
181
- def forward(self, x):
182
  """
183
  Returns output tensor after passing input `x` to the neural network.
184
 
@@ -192,6 +192,9 @@ class UNet(nn.Module):
192
  tensor
193
  output of the UNet.
194
  """
 
 
 
195
 
196
 
197
  x = F.leaky_relu(self.conv1(x), negative_slope = 0.1)
 
178
  self.up5 = up(64, 32)
179
  self.conv3 = nn.Conv2d(32, outChannels, 3, stride=1, padding=1)
180
 
181
+ def forward(self, x,time_steps=None):
182
  """
183
  Returns output tensor after passing input `x` to the neural network.
184
 
 
192
  tensor
193
  output of the UNet.
194
  """
195
+ if time_steps:
196
+ time_steps = time_steps.view(-1,1,1,1).expand(-1,1,x.size(2),x.size(3))
197
+ torch.cat((x,time_steps),1)
198
 
199
 
200
  x = F.leaky_relu(self.conv1(x), negative_slope = 0.1)