headscratchertm commited on
Commit
7917eea
·
1 Parent(s): 642ebc0

someone help fk this image

Browse files
Files changed (3) hide show
  1. __pycache__/frames.cpython-310.pyc +0 -0
  2. frames.py +1 -11
  3. main.py +16 -4
__pycache__/frames.cpython-310.pyc CHANGED
Binary files a/__pycache__/frames.cpython-310.pyc and b/__pycache__/frames.cpython-310.pyc differ
 
frames.py CHANGED
@@ -40,17 +40,7 @@ def load_frames(path,size=(128,128)) -> tensor: # converts PIL image to tensor o
40
  tensor = ToTensor()
41
  resized_image=Resize(size)(image)
42
  return tensor(resized_image).unsqueeze(0).to(device)
43
- def save_frames(Tensor,output_path)->None: # Tensor to image
44
- '''
45
- Used to Save the Interpolated frame into the output directory.
46
- :param Tensor:
47
- :param output_path:
48
- :return:
49
- '''
50
- transform=ToPILImage()
51
- image=Tensor.squeeze(0).cpu()
52
- image=transform(image)
53
- image.save(output_path)
54
 
55
 
56
  if __name__ == "__main__": # sets the __name__ variable to __main__ for this script
 
40
  tensor = ToTensor()
41
  resized_image=Resize(size)(image)
42
  return tensor(resized_image).unsqueeze(0).to(device)
43
+
 
 
 
 
 
 
 
 
 
 
44
 
45
 
46
  if __name__ == "__main__": # sets the __name__ variable to __main__ for this script
main.py CHANGED
@@ -1,17 +1,29 @@
1
  import torch
2
  from model import UNet
3
- from frames import load_frames,save_frames
4
  from PIL import Image
5
  from torchvision.transforms import transforms,ToTensor
6
 
7
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 
 
 
 
 
 
 
 
 
 
 
8
  def time_steps(input_fps,output_fps)->list[float]:
9
  if output_fps<=input_fps:
10
  return []
11
  k=output_fps//input_fps
12
  n=k-1
13
  return [i/n+1 for i in range(1,n+1)]
14
- def expand_channels(tensor,target):
15
  batch_size,current_channels,height,width=tensor.shape
16
  if current_channels>=target:
17
  return tensor
@@ -21,6 +33,7 @@ def expand_channels(tensor,target):
21
  def interpolate(model_FC,model_AT,A,B,input_fps,output_fps)-> list[float]:
22
  interval=time_steps(input_fps,output_fps)
23
  input_tensor=torch.cat((A,B),dim=1)
 
24
  with torch.no_grad():
25
  flow_output=model_FC(input_tensor)
26
  flow_output=expand_channels(flow_output,20)
@@ -44,8 +57,7 @@ def solve():
44
  model_FC.eval()
45
  A=load_frames("output/1.png")
46
  B=load_frames("output/69.png")
47
- interpolated_frames=interpolate(model_FC,model_AT,A,B,60,120)
48
- print(interpolated_frames)
49
  for index,value in enumerate(interpolated_frames):
50
  save_frames(value[:,:3,:,:],"Result_Test/image{}.png".format(index+1))
51
 
 
1
  import torch
2
  from model import UNet
3
+ from frames import load_frames
4
  from PIL import Image
5
  from torchvision.transforms import transforms,ToTensor
6
 
7
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
8
+ def save_frames(tensor,out_path)->None:
9
+ image=normalize_frames(tensor)
10
+ image=Image.fromarray(image)
11
+ image.save(out_path)
12
+ def normalize_frames(tensor):
13
+ tensor=tensor.squeeze(0).detach().cpu()
14
+ min_val=tensor.min()
15
+ max_val=tensor.max()
16
+ tensor=(tensor-min_val)/(max_val-min_val)
17
+ tensor=(tensor*255).byte()
18
+ tensor=tensor.permute(1,2,0).numpy()
19
+ return tensor
20
  def time_steps(input_fps,output_fps)->list[float]:
21
  if output_fps<=input_fps:
22
  return []
23
  k=output_fps//input_fps
24
  n=k-1
25
  return [i/n+1 for i in range(1,n+1)]
26
+ def expand_channels(tensor,target): # adding filler channels
27
  batch_size,current_channels,height,width=tensor.shape
28
  if current_channels>=target:
29
  return tensor
 
33
  def interpolate(model_FC,model_AT,A,B,input_fps,output_fps)-> list[float]:
34
  interval=time_steps(input_fps,output_fps)
35
  input_tensor=torch.cat((A,B),dim=1)
36
+ print(interval)
37
  with torch.no_grad():
38
  flow_output=model_FC(input_tensor)
39
  flow_output=expand_channels(flow_output,20)
 
57
  model_FC.eval()
58
  A=load_frames("output/1.png")
59
  B=load_frames("output/69.png")
60
+ interpolated_frames=interpolate(model_FC,model_AT,A,B,30,60)
 
61
  for index,value in enumerate(interpolated_frames):
62
  save_frames(value[:,:3,:,:],"Result_Test/image{}.png".format(index+1))
63