Commit
·
f33899f
1
Parent(s):
f90ddf2
image gen works fineish
Browse files
main.py
CHANGED
@@ -2,6 +2,7 @@ import torch
|
|
2 |
from model import UNet
|
3 |
from PIL import Image
|
4 |
from torchvision.transforms import transforms, ToTensor
|
|
|
5 |
|
6 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
7 |
|
@@ -17,7 +18,7 @@ def normalize_frames(tensor):
|
|
17 |
tensor = tensor.permute(1, 2, 0).numpy() # Convert to [H, W, C]
|
18 |
return tensor
|
19 |
|
20 |
-
def load_frames(image_path):
|
21 |
transform = transforms.Compose([
|
22 |
ToTensor() # Converts to [0, 1] range and [C, H, W]
|
23 |
])
|
@@ -42,20 +43,51 @@ def expand_channels(tensor, target):
|
|
42 |
|
43 |
def interpolate(model_FC, model_AT, A, B, input_fps, output_fps):
|
44 |
interval = time_steps(input_fps, output_fps)
|
45 |
-
input_tensor = torch.cat((A, B), dim=1)
|
46 |
-
|
47 |
with torch.no_grad():
|
48 |
-
flow_output = model_FC(input_tensor)
|
49 |
-
|
|
|
50 |
|
51 |
generated_frames = []
|
52 |
with torch.no_grad():
|
53 |
-
for
|
54 |
-
|
55 |
-
|
|
|
|
|
|
|
|
|
56 |
generated_frames.append(interpolated_frame)
|
|
|
57 |
return generated_frames
|
58 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
59 |
def solve():
|
60 |
checkpoint = torch.load("SuperSloMo.ckpt")
|
61 |
model_FC = UNet(6, 4).to(device) # Initialize flow computation model
|
@@ -67,8 +99,8 @@ def solve():
|
|
67 |
model_AT.eval()
|
68 |
|
69 |
A = load_frames("output/1.png")
|
70 |
-
B = load_frames("output/
|
71 |
-
interpolated_frames = interpolate(model_FC, model_AT, A, B, 30,
|
72 |
|
73 |
for index, value in enumerate(interpolated_frames):
|
74 |
save_frames(value[:, :3, :, :], f"Result_Test/image{index + 1}.png") # Save only RGB channels
|
|
|
2 |
from model import UNet
|
3 |
from PIL import Image
|
4 |
from torchvision.transforms import transforms, ToTensor
|
5 |
+
import torch.nn.functional as F
|
6 |
|
7 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
8 |
|
|
|
18 |
tensor = tensor.permute(1, 2, 0).numpy() # Convert to [H, W, C]
|
19 |
return tensor
|
20 |
|
21 |
+
def load_frames(image_path)->torch.Tensor:
|
22 |
transform = transforms.Compose([
|
23 |
ToTensor() # Converts to [0, 1] range and [C, H, W]
|
24 |
])
|
|
|
43 |
|
44 |
def interpolate(model_FC, model_AT, A, B, input_fps, output_fps):
|
45 |
interval = time_steps(input_fps, output_fps)
|
46 |
+
input_tensor = torch.cat((A, B), dim=1) # Combine frames A and B
|
47 |
+
|
48 |
with torch.no_grad():
|
49 |
+
flow_output = model_FC(input_tensor)
|
50 |
+
flow_forward = flow_output[:, :2, :, :] # Forward flow
|
51 |
+
flow_backward = flow_output[:, 2:4, :, :] # Backward flow
|
52 |
|
53 |
generated_frames = []
|
54 |
with torch.no_grad():
|
55 |
+
for t in interval:
|
56 |
+
t_tensor = torch.tensor([t], dtype=torch.float32).view(1, 1, 1, 1).to(device)
|
57 |
+
|
58 |
+
warped_A = warp_frames(A, flow_forward * t_tensor)
|
59 |
+
warped_B = warp_frames(B, flow_backward * (1 - t_tensor))
|
60 |
+
|
61 |
+
interpolated_frame = warped_A * (1 - t_tensor) + warped_B * t_tensor
|
62 |
generated_frames.append(interpolated_frame)
|
63 |
+
|
64 |
return generated_frames
|
65 |
|
66 |
+
|
67 |
+
def warp_frames(frame, flow):
|
68 |
+
b, c, h, w = frame.size()
|
69 |
+
_, _, flow_h, flow_w = flow.size()
|
70 |
+
|
71 |
+
if h != flow_h or w != flow_w:
|
72 |
+
frame = F.interpolate(frame, size=(flow_h, flow_w), mode='bilinear', align_corners=True)
|
73 |
+
|
74 |
+
grid_y, grid_x = torch.meshgrid(torch.arange(0, flow_h), torch.arange(0, flow_w), indexing="ij")
|
75 |
+
grid_x = grid_x.float().to(device)
|
76 |
+
grid_y = grid_y.float().to(device)
|
77 |
+
|
78 |
+
flow_x = flow[:, 0, :, :]
|
79 |
+
flow_y = flow[:, 1, :, :]
|
80 |
+
x = grid_x.unsqueeze(0) + flow_x
|
81 |
+
y = grid_y.unsqueeze(0) + flow_y
|
82 |
+
|
83 |
+
x = 2.0 * x / (flow_w - 1) - 1.0
|
84 |
+
y = 2.0 * y / (flow_h - 1) - 1.0
|
85 |
+
grid = torch.stack((x, y), dim=-1)
|
86 |
+
|
87 |
+
warped_frame = F.grid_sample(frame, grid, align_corners=True)
|
88 |
+
return warped_frame
|
89 |
+
|
90 |
+
|
91 |
def solve():
|
92 |
checkpoint = torch.load("SuperSloMo.ckpt")
|
93 |
model_FC = UNet(6, 4).to(device) # Initialize flow computation model
|
|
|
99 |
model_AT.eval()
|
100 |
|
101 |
A = load_frames("output/1.png")
|
102 |
+
B = load_frames("output/10.png")
|
103 |
+
interpolated_frames = interpolate(model_FC, model_AT, A, B, 30, 90)
|
104 |
|
105 |
for index, value in enumerate(interpolated_frames):
|
106 |
save_frames(value[:, :3, :, :], f"Result_Test/image{index + 1}.png") # Save only RGB channels
|