headscratchertm committed
Commit f90ddf2 · Parent: 7917eea

image is getting generated, but the quality is very poor

Files changed (3)
  1. __pycache__/model.cpython-310.pyc +0 -0
  2. main.py +58 -46
  3. model.py +11 -25
__pycache__/model.cpython-310.pyc CHANGED
Binary files a/__pycache__/model.cpython-310.pyc and b/__pycache__/model.cpython-310.pyc differ
 
main.py CHANGED
@@ -1,68 +1,80 @@
  import torch
  from model import UNet
- from frames import load_frames
  from PIL import Image
- from torchvision.transforms import transforms,ToTensor
+ from torchvision.transforms import transforms, ToTensor

  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- def save_frames(tensor,out_path)->None:
-     image=normalize_frames(tensor)
-     image=Image.fromarray(image)
+
+ def save_frames(tensor, out_path) -> None:
+     image = normalize_frames(tensor)
+     image = Image.fromarray(image)
      image.save(out_path)
+
  def normalize_frames(tensor):
-     tensor=tensor.squeeze(0).detach().cpu()
-     min_val=tensor.min()
-     max_val=tensor.max()
-     tensor=(tensor-min_val)/(max_val-min_val)
-     tensor=(tensor*255).byte()
-     tensor=tensor.permute(1,2,0).numpy()
+     tensor = tensor.squeeze(0).detach().cpu()
+     tensor = torch.clamp(tensor, 0.0, 1.0) # Ensure values are in [0, 1]
+     tensor = (tensor * 255).byte() # Scale to [0, 255]
+     tensor = tensor.permute(1, 2, 0).numpy() # Convert to [H, W, C]
+     return tensor
+
+ def load_frames(image_path):
+     transform = transforms.Compose([
+         ToTensor() # Converts to [0, 1] range and [C, H, W]
+     ])
+     img = Image.open(image_path).convert("RGB")
+     tensor = transform(img).unsqueeze(0).to(device) # Add batch dimension
      return tensor
- def time_steps(input_fps,output_fps)->list[float]:
-     if output_fps<=input_fps:
+
+ def time_steps(input_fps, output_fps) -> list[float]:
+     if output_fps <= input_fps:
          return []
-     k=output_fps//input_fps
-     n=k-1
-     return [i/n+1 for i in range(1,n+1)]
- def expand_channels(tensor,target): # adding filler channels
-     batch_size,current_channels,height,width=tensor.shape
-     if current_channels>=target:
+     k = output_fps // input_fps
+     n = k - 1
+     return [i / (n + 1) for i in range(1, n + 1)]
+
+ def expand_channels(tensor, target):
+     batch_size, current_channels, height, width = tensor.shape
+     if current_channels >= target:
          return tensor
-     required=target-current_channels
-     extra=torch.zeros(batch_size,required,height,width,device=tensor.device,dtype=tensor.dtype)
-     return torch.cat((tensor,extra),dim=1)
- def interpolate(model_FC,model_AT,A,B,input_fps,output_fps)-> list[float]:
-     interval=time_steps(input_fps,output_fps)
-     input_tensor=torch.cat((A,B),dim=1)
-     print(interval)
+     required = target - current_channels
+     extra = torch.zeros(batch_size, required, height, width, device=tensor.device, dtype=tensor.dtype)
+     return torch.cat((tensor, extra), dim=1)
+
+ def interpolate(model_FC, model_AT, A, B, input_fps, output_fps):
+     interval = time_steps(input_fps, output_fps)
+     input_tensor = torch.cat((A, B), dim=1)
+     print(f"Time intervals: {interval}")
      with torch.no_grad():
-         flow_output=model_FC(input_tensor)
-         flow_output=expand_channels(flow_output,20)
-     generated_frames=[]
+         flow_output = model_FC(input_tensor) # Output shape: [1, 4, H, W]
+         flow_output = expand_channels(flow_output, 20) # Expand to 20 channels
+
+     generated_frames = []
      with torch.no_grad():
          for i in interval:
-             inter_tensor=torch.tensor([i],dtype=torch.float32).unsqueeze(0).to(device)
-             interpolated_frame=model_AT(flow_output,inter_tensor)
+             inter_tensor = torch.tensor([i], dtype=torch.float32).unsqueeze(0).to(device)
+             interpolated_frame = model_AT(flow_output, inter_tensor)
              generated_frames.append(interpolated_frame)
      return generated_frames

  def solve():
-     checkpoint=torch.load("SuperSloMo.ckpt")
-     model_FC=UNet(6,4) # initialize ARCH
-     model_FC=model_FC.to(device)# reassign model tensors
-     model_FC.load_state_dict(checkpoint["state_dictFC"]) # loading all weights from model
-     model_AT=UNet(20,5)
-     model_AT.load_state_dict(checkpoint["state_dictAT"],strict=False)
-     model_AT=model_AT.to(device)
-     model_AT.eval()
+     checkpoint = torch.load("SuperSloMo.ckpt")
+     model_FC = UNet(6, 4).to(device) # Initialize flow computation model
+     model_FC.load_state_dict(checkpoint["state_dictFC"]) # Load weights
      model_FC.eval()
-     A=load_frames("output/1.png")
-     B=load_frames("output/69.png")
-     interpolated_frames=interpolate(model_FC,model_AT,A,B,30,60)
-     for index,value in enumerate(interpolated_frames):
-         save_frames(value[:,:3,:,:],"Result_Test/image{}.png".format(index+1))

+     model_AT = UNet(20, 5).to(device) # Initialize auxiliary task model
+     model_AT.load_state_dict(checkpoint["state_dictAT"], strict=False) # Load weights
+     model_AT.eval()
+
+     A = load_frames("output/1.png")
+     B = load_frames("output/69.png")
+     interpolated_frames = interpolate(model_FC, model_AT, A, B, 30, 60)
+
+     for index, value in enumerate(interpolated_frames):
+         save_frames(value[:, :3, :, :], f"Result_Test/image{index + 1}.png") # Save only RGB channels

  def main():
      solve()
- if __name__=="__main__":
-     main()
+
+ if __name__ == "__main__":
+     main()
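
Note on the time_steps fix: the old comprehension [i/n+1 for i in range(1,n+1)] parses as (i/n) + 1, so every interpolation instant fell outside the valid (0, 1) range; the new [i / (n + 1) ...] yields evenly spaced fractions strictly between 0 and 1. A minimal sanity check (hypothetical snippet, not part of the commit; assumes time_steps is importable from main.py):

    from main import time_steps  # assumed import path

    assert time_steps(30, 60) == [0.5]               # 2x rate: one midpoint per frame pair
    assert time_steps(30, 120) == [0.25, 0.5, 0.75]  # 4x rate: three evenly spaced instants
    assert time_steps(60, 30) == []                  # downsampling: nothing to insert
    print("time_steps OK")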
model.py CHANGED
@@ -107,35 +107,21 @@ class up(nn.Module):
          self.conv1 = nn.Conv2d(inChannels, outChannels, 3, stride=1, padding=1)
          # (2 * outChannels) is used for accommodating skip connection.
          self.conv2 = nn.Conv2d(2 * outChannels, outChannels, 3, stride=1, padding=1)
-
-     def forward(self, x, skpCn):
-         """
-         Returns output tensor after passing input `x` to the neural network
-         block.
-
-         Parameters
-         ----------
-         x : tensor
-             input to the NN block.
-         skpCn : tensor
-             skip connection input to the NN block.

-         Returns
-         -------
-         tensor
-             output of the NN block.
-         """
-
-         # Bilinear interpolation with scaling 2.
-         x = F.interpolate(x, scale_factor=2, mode='bilinear')
-         # Convolution + Leaky ReLU
-         x = F.leaky_relu(self.conv1(x), negative_slope = 0.1)
-         # Convolution + Leaky ReLU on (`x`, `skpCn`)
-         x = F.leaky_relu(self.conv2(torch.cat((x, skpCn), 1)), negative_slope = 0.1)
+     def forward(self, x, skpCn):
+         # Upsample x
+         x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)
+         # Match dimensions by cropping the skip connection (skpCn) to match x
+         if x.size(-1) != skpCn.size(-1):
+             skpCn = skpCn[:, :, :, :x.size(-1)]
+         if x.size(-2) != skpCn.size(-2):
+             skpCn = skpCn[:, :, :x.size(-2), :]
+         # Concatenate and apply convolutions
+         x = F.leaky_relu(self.conv1(x), negative_slope=0.1)
+         x = F.leaky_relu(self.conv2(torch.cat((x, skpCn), 1)), negative_slope=0.1)
          return x


-
  class UNet(nn.Module):
      """
      A class for creating UNet like architecture as specified by the
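
The rewritten up.forward crops the skip connection before concatenation: after 2x bilinear upsampling, x can end up one pixel smaller than skpCn whenever a downsampled feature map had odd height or width, and torch.cat would then raise a size-mismatch error. A small illustration with hypothetical shapes (not part of the commit):

    import torch

    x = torch.zeros(1, 64, 14, 14)      # decoder features after 2x upsampling (15 -> 7 -> 14)
    skpCn = torch.zeros(1, 64, 15, 15)  # encoder features from the odd-sized level
    if x.size(-1) != skpCn.size(-1):    # crop width to match x
        skpCn = skpCn[:, :, :, :x.size(-1)]
    if x.size(-2) != skpCn.size(-2):    # crop height to match x
        skpCn = skpCn[:, :, :x.size(-2), :]
    print(torch.cat((x, skpCn), 1).shape)  # torch.Size([1, 128, 14, 14])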