import os
import subprocess

import torch
import torch.nn.functional as F
from torch.cuda.amp import autocast
from torchvision.transforms import Compose, Resize, ToTensor
from PIL import Image
import gradio as gr

from model import UNet
from frames import extract_frames


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def save_frames(tensor, out_path) -> None:
    """Write a frame tensor to disk as a PNG image."""
    image = normalize_frames(tensor)
    Image.fromarray(image).save(out_path)


def normalize_frames(tensor):
    """Convert a (1, C, H, W) float tensor in [0, 1] into an (H, W, C) uint8 array."""
    tensor = tensor.squeeze(0).detach().cpu()
    tensor = torch.clamp(tensor, 0.0, 1.0)  # ensure values are in [0, 1]
    tensor = (tensor * 255).byte()  # scale to [0, 255]
    return tensor.permute(1, 2, 0).numpy()  # (C, H, W) -> (H, W, C)


def load_allframes(frame_dir):
    """Yield frames from frame_dir as tensors, sorted by their numeric suffix."""
    frame_paths = sorted(
        [
            os.path.join(frame_dir, f)
            for f in os.listdir(frame_dir)
            if f.endswith(".png")
        ],
        key=lambda x: int(os.path.splitext(os.path.basename(x))[0].split("_")[-1]),
    )
    for frame_path in frame_paths:
        yield load_frames(frame_path)


def load_frames(image_path) -> torch.Tensor:
    """
    Load an image as RGB, resize it to 720x1280, and move it to `device`.
    :param image_path: path to the frame image
    :return: tensor of shape (1, 3, 720, 1280) with values in [0, 1]
    """
    transform = Compose([Resize((720, 1280)), ToTensor()])
    img = Image.open(image_path).convert("RGB")
    return transform(img).unsqueeze(0).to(device)


def time_steps(input_fps, output_fps) -> list[float]:
    """
    Generate the intermediate time points at which to interpolate between two
    consecutive frames A and B.
    :param input_fps: FPS of the original video
    :param output_fps: target FPS of the output video
    :return: list of fractions t in (0, 1); empty if no upsampling is needed
    """
    if output_fps <= input_fps:
        return []
    # Assumes output_fps is an integer multiple of input_fps; any remainder
    # is truncated by the floor division.
    k = output_fps // input_fps
    n = k - 1
    return [i / (n + 1) for i in range(1, n + 1)]
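

# Worked example: time_steps(30, 120) -> [0.25, 0.5, 0.75], i.e. three evenly
# spaced in-between frames per consecutive pair, quadrupling the frame count.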


def interpolate_video(frames_dir, model_fc, input_fps, output_fps, output_dir):
    """Save each original frame followed by its interpolated in-between frames."""
    os.makedirs(output_dir, exist_ok=True)
    count = 0
    iterator = load_allframes(frames_dir)
    try:
        prev_frame = next(iterator)
        for curr_frame in iterator:
            interpolated_frames = interpolate(
                model_fc, prev_frame, curr_frame, input_fps, output_fps
            )
            save_frames(prev_frame, os.path.join(output_dir, f"frame_{count}.png"))
            count += 1
            for frame in interpolated_frames:
                save_frames(
                    frame[:, :3, :, :],
                    os.path.join(output_dir, f"frame_{count}.png"),
                )
                count += 1
            prev_frame = curr_frame
        save_frames(prev_frame, os.path.join(output_dir, f"frame_{count}.png"))
    except StopIteration:
        print("No frames found in", frames_dir)


def interpolate(model_FC, A, B, input_fps, output_fps) -> list[torch.Tensor]:
    """Generate the intermediate frames between frames A and B."""
    intervals = time_steps(input_fps, output_fps)
    input_tensor = torch.cat((A, B), dim=1)  # stack A and B along the channel axis
    with torch.no_grad():
        flow_output = model_FC(input_tensor)
        flow_forward = flow_output[:, :2, :, :]  # forward flow (A -> B)
        flow_backward = flow_output[:, 2:4, :, :]  # backward flow (B -> A)
        generated_frames = []
        for t in intervals:
            t_tensor = (
                torch.tensor([t], dtype=torch.float32).view(1, 1, 1, 1).to(device)
            )
            with autocast():
                warped_A = warp_frames(A, flow_forward * t_tensor)
                warped_B = warp_frames(B, flow_backward * (1 - t_tensor))
                # Linear blend: frames close to A weight warped_A more
                # heavily, and vice versa for frames close to B
                interpolated_frame = warped_A * (1 - t_tensor) + warped_B * t_tensor
            generated_frames.append(interpolated_frame)
    return generated_frames
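

# A minimal, hypothetical sanity check for interpolate(), not part of the
# pipeline: with a dummy model predicting zero flow everywhere, warp_frames
# is an identity map, so the single midpoint frame of a 2x upsample should
# reduce to a plain 50/50 cross-fade of A and B.
def _interpolate_smoke_test():
    def dummy_flow_model(x):
        # Stand-in for the trained flow-computation UNet: zero flow everywhere
        return torch.zeros(x.size(0), 4, x.size(2), x.size(3), device=x.device)

    A = torch.rand(1, 3, 64, 64, device=device)
    B = torch.rand(1, 3, 64, 64, device=device)
    (mid,) = interpolate(dummy_flow_model, A, B, input_fps=30, output_fps=60)
    assert torch.allclose(mid.float(), 0.5 * A + 0.5 * B, atol=1e-2)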


def warp_frames(frame, flow):
    """Warp `frame` along the optical flow field with bilinear sampling."""
    b, c, h, w = frame.size()
    _, _, flow_h, flow_w = flow.size()
    if h != flow_h or w != flow_w:
        frame = F.interpolate(
            frame, size=(flow_h, flow_w), mode="bilinear", align_corners=True
        )
    # Build a pixel-coordinate grid and displace it by the flow vectors
    grid_y, grid_x = torch.meshgrid(
        torch.arange(0, flow_h), torch.arange(0, flow_w), indexing="ij"
    )
    grid_x = grid_x.float().to(device)
    grid_y = grid_y.float().to(device)
    flow_x = flow[:, 0, :, :]
    flow_y = flow[:, 1, :, :]
    x = grid_x.unsqueeze(0) + flow_x
    y = grid_y.unsqueeze(0) + flow_y
    # Normalize coordinates to [-1, 1], the convention grid_sample expects
    x = 2.0 * x / (flow_w - 1) - 1.0
    y = 2.0 * y / (flow_h - 1) - 1.0
    grid = torch.stack((x, y), dim=-1)

    warped_frame = F.grid_sample(
        frame, grid, align_corners=True, mode="bilinear", padding_mode="border"
    )
    return warped_frame
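

# Coordinate convention, for reference: with flow_w = 1280, column 0 maps to
# -1.0 and column 1279 maps to +1.0 (matching align_corners=True), so a zero
# flow field samples every pixel from its own location, i.e. an identity warp.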


def frames_to_video(frame_dir, output_video, fps):
    """Renumber the frames consecutively and encode them into an H.264 video."""
    frame_files = sorted(
        [f for f in os.listdir(frame_dir) if f.endswith(".png")],
        key=lambda x: int(os.path.splitext(x)[0].split("_")[-1]),
    )
    # Renumber the files into a gapless frame_0.png ... frame_N.png sequence,
    # which the ffmpeg %d pattern below requires
    for i, frame in enumerate(frame_files):
        os.rename(
            os.path.join(frame_dir, frame), os.path.join(frame_dir, f"frame_{i}.png")
        )
    frame_pattern = os.path.join(frame_dir, "frame_%d.png")
    subprocess.run(
        [
            "ffmpeg",
            "-framerate",
            str(fps),
            "-i",
            frame_pattern,
            "-c:v",
            "libx264",
            "-pix_fmt",
            "yuv420p",
            "-y",  # overwrite the output file if it already exists
            output_video,
        ],
        check=True,
    )
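

# For reference, the subprocess call above is equivalent to running, e.g. with
# fps=60 and frame_dir="interpolated_frames":
#   ffmpeg -framerate 60 -i interpolated_frames/frame_%d.png \
#       -c:v libx264 -pix_fmt yuv420p -y result.mp4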


def process_video(video_path, output_fps):
    """Gradio handler: extract frames, interpolate to output_fps, and encode the result."""
    # Extract the input video's frames and recover its native FPS
    input_fps = extract_frames(video_path, "output_frames")

    # Load the pretrained flow-computation model
    model_FC = UNet(6, 4).to(device)
    checkpoint = torch.load("SuperSloMo.ckpt", map_location=device)
    model_FC.load_state_dict(checkpoint["state_dictFC"])
    model_FC.eval()

    # Interpolate the extracted frames up to the requested FPS
    output_dir = "interpolated_frames"
    interpolate_video("output_frames", model_FC, input_fps, output_fps, output_dir)

    # Encode the interpolated frames into the final video
    final_video_path = "result.mp4"
    frames_to_video(output_dir, final_video_path, output_fps)

    return final_video_path
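

# For example, process_video("clip.mp4", 60) (hypothetical input file) writes
# the 60 fps result to result.mp4 and returns that path.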


interface = gr.Interface(
    fn=process_video,
    inputs=[
        gr.Video(label="Upload Input Video"),
        gr.Slider(minimum=30, maximum=120, step=1, value=60, label="Desired Output FPS"),
    ],
    outputs=gr.File(label="Download Enhanced Video"),
    title="Video Frame Interpolation with SuperSloMo",
    description="Upload a video and increase its frame rate with deep-learning-based frame interpolation.",
)

if __name__ == "__main__":
    interface.launch()