headscratchertm commited on
Commit
f33899f
·
1 Parent(s): f90ddf2

image gen works fineish

Browse files
Files changed (1) hide show
  1. main.py +42 -10
main.py CHANGED
@@ -2,6 +2,7 @@ import torch
2
  from model import UNet
3
  from PIL import Image
4
  from torchvision.transforms import transforms, ToTensor
 
5
 
6
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
7
 
@@ -17,7 +18,7 @@ def normalize_frames(tensor):
17
  tensor = tensor.permute(1, 2, 0).numpy() # Convert to [H, W, C]
18
  return tensor
19
 
20
- def load_frames(image_path):
21
  transform = transforms.Compose([
22
  ToTensor() # Converts to [0, 1] range and [C, H, W]
23
  ])
@@ -42,20 +43,51 @@ def expand_channels(tensor, target):
42
 
43
  def interpolate(model_FC, model_AT, A, B, input_fps, output_fps):
44
  interval = time_steps(input_fps, output_fps)
45
- input_tensor = torch.cat((A, B), dim=1)
46
- print(f"Time intervals: {interval}")
47
  with torch.no_grad():
48
- flow_output = model_FC(input_tensor) # Output shape: [1, 4, H, W]
49
- flow_output = expand_channels(flow_output, 20) # Expand to 20 channels
 
50
 
51
  generated_frames = []
52
  with torch.no_grad():
53
- for i in interval:
54
- inter_tensor = torch.tensor([i], dtype=torch.float32).unsqueeze(0).to(device)
55
- interpolated_frame = model_AT(flow_output, inter_tensor)
 
 
 
 
56
  generated_frames.append(interpolated_frame)
 
57
  return generated_frames
58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  def solve():
60
  checkpoint = torch.load("SuperSloMo.ckpt")
61
  model_FC = UNet(6, 4).to(device) # Initialize flow computation model
@@ -67,8 +99,8 @@ def solve():
67
  model_AT.eval()
68
 
69
  A = load_frames("output/1.png")
70
- B = load_frames("output/69.png")
71
- interpolated_frames = interpolate(model_FC, model_AT, A, B, 30, 60)
72
 
73
  for index, value in enumerate(interpolated_frames):
74
  save_frames(value[:, :3, :, :], f"Result_Test/image{index + 1}.png") # Save only RGB channels
 
2
  from model import UNet
3
  from PIL import Image
4
  from torchvision.transforms import transforms, ToTensor
5
+ import torch.nn.functional as F
6
 
7
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
8
 
 
18
  tensor = tensor.permute(1, 2, 0).numpy() # Convert to [H, W, C]
19
  return tensor
20
 
21
+ def load_frames(image_path)->torch.Tensor:
22
  transform = transforms.Compose([
23
  ToTensor() # Converts to [0, 1] range and [C, H, W]
24
  ])
 
43
 
44
  def interpolate(model_FC, model_AT, A, B, input_fps, output_fps):
45
  interval = time_steps(input_fps, output_fps)
46
+ input_tensor = torch.cat((A, B), dim=1) # Combine frames A and B
47
+
48
  with torch.no_grad():
49
+ flow_output = model_FC(input_tensor)
50
+ flow_forward = flow_output[:, :2, :, :] # Forward flow
51
+ flow_backward = flow_output[:, 2:4, :, :] # Backward flow
52
 
53
  generated_frames = []
54
  with torch.no_grad():
55
+ for t in interval:
56
+ t_tensor = torch.tensor([t], dtype=torch.float32).view(1, 1, 1, 1).to(device)
57
+
58
+ warped_A = warp_frames(A, flow_forward * t_tensor)
59
+ warped_B = warp_frames(B, flow_backward * (1 - t_tensor))
60
+
61
+ interpolated_frame = warped_A * (1 - t_tensor) + warped_B * t_tensor
62
  generated_frames.append(interpolated_frame)
63
+
64
  return generated_frames
65
 
66
+
67
+ def warp_frames(frame, flow):
68
+ b, c, h, w = frame.size()
69
+ _, _, flow_h, flow_w = flow.size()
70
+
71
+ if h != flow_h or w != flow_w:
72
+ frame = F.interpolate(frame, size=(flow_h, flow_w), mode='bilinear', align_corners=True)
73
+
74
+ grid_y, grid_x = torch.meshgrid(torch.arange(0, flow_h), torch.arange(0, flow_w), indexing="ij")
75
+ grid_x = grid_x.float().to(device)
76
+ grid_y = grid_y.float().to(device)
77
+
78
+ flow_x = flow[:, 0, :, :]
79
+ flow_y = flow[:, 1, :, :]
80
+ x = grid_x.unsqueeze(0) + flow_x
81
+ y = grid_y.unsqueeze(0) + flow_y
82
+
83
+ x = 2.0 * x / (flow_w - 1) - 1.0
84
+ y = 2.0 * y / (flow_h - 1) - 1.0
85
+ grid = torch.stack((x, y), dim=-1)
86
+
87
+ warped_frame = F.grid_sample(frame, grid, align_corners=True)
88
+ return warped_frame
89
+
90
+
91
  def solve():
92
  checkpoint = torch.load("SuperSloMo.ckpt")
93
  model_FC = UNet(6, 4).to(device) # Initialize flow computation model
 
99
  model_AT.eval()
100
 
101
  A = load_frames("output/1.png")
102
+ B = load_frames("output/10.png")
103
+ interpolated_frames = interpolate(model_FC, model_AT, A, B, 30, 90)
104
 
105
  for index, value in enumerate(interpolated_frames):
106
  save_frames(value[:, :3, :, :], f"Result_Test/image{index + 1}.png") # Save only RGB channels