Commit f90ddf2 · 1 Parent(s): 7917eea

image is getting generated but it's very ass

Browse files
- __pycache__/model.cpython-310.pyc +0 -0
- main.py +58 -46
- model.py +11 -25

__pycache__/model.cpython-310.pyc CHANGED
Binary files a/__pycache__/model.cpython-310.pyc and b/__pycache__/model.cpython-310.pyc differ
main.py CHANGED
@@ -1,68 +1,80 @@
 import torch
 from model import UNet
-from frames import load_frames
 from PIL import Image
-from torchvision.transforms import transforms,ToTensor
+from torchvision.transforms import transforms, ToTensor
 
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-
-    image=
+
+def save_frames(tensor, out_path) -> None:
+    image = normalize_frames(tensor)
+    image = Image.fromarray(image)
     image.save(out_path)
+
 def normalize_frames(tensor):
-    tensor=tensor.squeeze(0).detach().cpu()
-
-
-    tensor=(
-    tensor
-
+    tensor = tensor.squeeze(0).detach().cpu()
+    tensor = torch.clamp(tensor, 0.0, 1.0)  # Ensure values are in [0, 1]
+    tensor = (tensor * 255).byte()  # Scale to [0, 255]
+    tensor = tensor.permute(1, 2, 0).numpy()  # Convert to [H, W, C]
+    return tensor
+
+def load_frames(image_path):
+    transform = transforms.Compose([
+        ToTensor()  # Converts to [0, 1] range and [C, H, W]
+    ])
+    img = Image.open(image_path).convert("RGB")
+    tensor = transform(img).unsqueeze(0).to(device)  # Add batch dimension
     return tensor
-
-
+
+def time_steps(input_fps, output_fps) -> list[float]:
+    if output_fps <= input_fps:
         return []
-    k=output_fps//input_fps
-    n=k-1
-    return [i/n+1 for i in range(1,n+1)]
-
-
-
+    k = output_fps // input_fps
+    n = k - 1
+    return [i / (n + 1) for i in range(1, n + 1)]
+
+def expand_channels(tensor, target):
+    batch_size, current_channels, height, width = tensor.shape
+    if current_channels >= target:
         return tensor
-    required=target-current_channels
-    extra=torch.zeros(batch_size,required,height,width,device=tensor.device,dtype=tensor.dtype)
-    return torch.cat((tensor,extra),dim=1)
-
-
-
-
+    required = target - current_channels
+    extra = torch.zeros(batch_size, required, height, width, device=tensor.device, dtype=tensor.dtype)
+    return torch.cat((tensor, extra), dim=1)
+
+def interpolate(model_FC, model_AT, A, B, input_fps, output_fps):
+    interval = time_steps(input_fps, output_fps)
+    input_tensor = torch.cat((A, B), dim=1)
+    print(f"Time intervals: {interval}")
     with torch.no_grad():
-        flow_output=model_FC(input_tensor)
-        flow_output=expand_channels(flow_output,20)
-
+        flow_output = model_FC(input_tensor)  # Output shape: [1, 4, H, W]
+        flow_output = expand_channels(flow_output, 20)  # Expand to 20 channels
+
+    generated_frames = []
     with torch.no_grad():
         for i in interval:
-            inter_tensor=torch.tensor([i],dtype=torch.float32).unsqueeze(0).to(device)
-            interpolated_frame=model_AT(flow_output,inter_tensor)
+            inter_tensor = torch.tensor([i], dtype=torch.float32).unsqueeze(0).to(device)
+            interpolated_frame = model_AT(flow_output, inter_tensor)
             generated_frames.append(interpolated_frame)
     return generated_frames
 
 def solve():
-    checkpoint=torch.load("SuperSloMo.ckpt")
-    model_FC=UNet(6,4)
-    model_FC
-    model_FC.load_state_dict(checkpoint["state_dictFC"]) # loading all weights from model
-    model_AT=UNet(20,5)
-    model_AT.load_state_dict(checkpoint["state_dictAT"],strict=False)
-    model_AT=model_AT.to(device)
-    model_AT.eval()
+    checkpoint = torch.load("SuperSloMo.ckpt")
+    model_FC = UNet(6, 4).to(device)  # Initialize flow computation model
+    model_FC.load_state_dict(checkpoint["state_dictFC"])  # Load weights
     model_FC.eval()
-    A=load_frames("output/1.png")
-    B=load_frames("output/69.png")
-    interpolated_frames=interpolate(model_FC,model_AT,A,B,30,60)
-    for index,value in enumerate(interpolated_frames):
-        save_frames(value[:,:3,:,:],"Result_Test/image{}.png".format(index+1))
 
+    model_AT = UNet(20, 5).to(device)  # Initialize auxiliary task model
+    model_AT.load_state_dict(checkpoint["state_dictAT"], strict=False)  # Load weights
+    model_AT.eval()
+
+    A = load_frames("output/1.png")
+    B = load_frames("output/69.png")
+    interpolated_frames = interpolate(model_FC, model_AT, A, B, 30, 60)
+
+    for index, value in enumerate(interpolated_frames):
+        save_frames(value[:, :3, :, :], f"Result_Test/image{index + 1}.png")  # Save only RGB channels
 
 def main():
     solve()
-
-
+
+if __name__ == "__main__":
+    main()
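For reference, the two new helpers in main.py can be sanity-checked in isolation. A minimal sketch (assuming the file is importable as `main` from the repo root; the tensor sizes are arbitrary):

    import torch
    from main import time_steps, expand_channels

    # Doubling 30 fps to 60 fps yields one midpoint per input frame pair
    print(time_steps(30, 60))  # [0.5]

    # expand_channels zero-pads the 4-channel flow output up to the
    # 20 input channels expected by model_AT = UNet(20, 5)
    flow = torch.randn(1, 4, 64, 64)
    print(expand_channels(flow, 20).shape)  # torch.Size([1, 20, 64, 64])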
model.py CHANGED
@@ -107,35 +107,21 @@ class up(nn.Module):
         self.conv1 = nn.Conv2d(inChannels, outChannels, 3, stride=1, padding=1)
         # (2 * outChannels) is used for accommodating skip connection.
         self.conv2 = nn.Conv2d(2 * outChannels, outChannels, 3, stride=1, padding=1)
-
-    def forward(self, x, skpCn):
-        """
-        Returns output tensor after passing input `x` to the neural network
-        block.
-
-        Parameters
-        ----------
-            x : tensor
-                input to the NN block.
-            skpCn : tensor
-                skip connection input to the NN block.
 
-        Returns
-        -------
-            tensor
-                output of the NN block.
-        """
-
-
-
-        #
-        x = F.leaky_relu(self.conv1(x), negative_slope = 0.1)
-
-        x = F.leaky_relu(self.conv2(torch.cat((x, skpCn), 1)), negative_slope = 0.1)
+    def forward(self, x, skpCn):
+        # Upsample x
+        x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)
+        # Match dimensions by cropping the skip connection (skpCn) to match x
+        if x.size(-1) != skpCn.size(-1):
+            skpCn = skpCn[:, :, :, :x.size(-1)]
+        if x.size(-2) != skpCn.size(-2):
+            skpCn = skpCn[:, :, :x.size(-2), :]
+        # Concatenate and apply convolutions
+        x = F.leaky_relu(self.conv1(x), negative_slope=0.1)
+        x = F.leaky_relu(self.conv2(torch.cat((x, skpCn), 1)), negative_slope=0.1)
         return x
 
 
-
 class UNet(nn.Module):
     """
     A class for creating UNet like architecture as specified by the