# Scraped page header (not code): "Spaces: Runtime error"
import argparse
import pickle as pkl

import decord
import gradio as gr
import numpy as np
import torch
import yaml
from decord import VideoReader

from cover.datasets import UnifiedFrameSampler, spatial_temporal_view_decomposition
from cover.models import COVER
# Per-channel normalisation statistics (three entries each, float32).
# `mean` / `std` are applied to the "technical" and "aesthetic" views and
# `mean_clip` / `std_clip` to the "semantic" view in inference_one_video.
# NOTE(review): the values look like the ImageNet and CLIP RGB statistics
# scaled to a 0-255 pixel range -- confirm against the training pipeline.
mean = torch.FloatTensor([123.675, 116.28, 103.53])
std = torch.FloatTensor([58.395, 57.12, 57.375])

mean_clip = torch.FloatTensor([122.77, 116.75, 104.09])
std_clip = torch.FloatTensor([68.50, 66.63, 70.32])
def fuse_results(results: list):
    """Label the three per-branch scores and add their sum.

    Args:
        results: Sequence whose first three items are the semantic,
            technical and aesthetic scores, in that order. Any further
            items are ignored.

    Returns:
        A dict with the keys "semantic", "technical", "aesthetic" and
        "overall", where "overall" is the sum of the first three items.

    Raises:
        IndexError: If ``results`` holds fewer than three items.
    """
    semantic, technical, aesthetic = results[0], results[1], results[2]
    return {
        "semantic": semantic,
        "technical": technical,
        "aesthetic": aesthetic,
        "overall": semantic + technical + aesthetic,
    }
def _normalize_view(view, num_clips, channel_mean, channel_std, device):
    """Normalise one sampled view, split it into clips and move it to ``device``.

    NOTE(review): assumes dim 0 of ``view`` is the colour-channel axis (the
    three-entry statistics broadcast against it once it is moved last) and
    that dim 1 is the axis being split into ``num_clips`` -- confirm against
    ``spatial_temporal_view_decomposition``.
    """
    # Move dim 0 last so the per-channel statistics broadcast, then restore it.
    scaled = (
        (view.permute(1, 2, 3, 0) - channel_mean) / channel_std
    ).permute(3, 0, 1, 2)
    return (
        scaled.reshape(view.shape[0], num_clips, -1, *view.shape[2:])
        .transpose(0, 1)
        .to(device)
    )


def inference_one_video(input_video):
    """Score a single video with the COVER model.

    Reads ``./cover.yml``, builds the frame samplers and the model described
    there, loads the checkpoint named by ``test_load_path`` and evaluates the
    video's "semantic", "technical" and "aesthetic" views.

    NOTE: the configuration and the model weights are reloaded on every call.

    Args:
        input_video: The video to score; passed unchanged to
            ``spatial_temporal_view_decomposition``.

    Returns:
        The dict produced by ``fuse_results``: one score per branch plus
        their sum under "overall".
    """
    # --- Basic settings ---------------------------------------------------
    # Only touch the CUDA runtime when a GPU exists: current_device() raises
    # on CPU-only hosts, which would make the CPU fallback below unreachable.
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        torch.cuda.current_device()
        torch.cuda.empty_cache()
    torch.backends.cudnn.benchmark = True
    device = torch.device("cuda" if use_cuda else "cpu")

    with open("./cover.yml", "r") as f:
        opt = yaml.safe_load(f)

    sample_types = opt["data"]["val-ytugc"]["args"]["sample_types"]
    temporal_samplers = {
        stype: UnifiedFrameSampler(
            sopt["clip_len"] // sopt["t_frag"],
            sopt["t_frag"],
            sopt["frame_interval"],
            sopt["num_clips"],
        )
        for stype, sopt in sample_types.items()
    }

    # --- Load model -------------------------------------------------------
    evaluator = COVER(**opt["model"]["args"]).to(device)
    state_dict = torch.load(opt["test_load_path"], map_location=device)
    # strict=False avoids errors about the missing weights of the
    # prompt_learner in clip-iqa+ and the cross-gate.
    evaluator.load_state_dict(state_dict["state_dict"], strict=False)
    # Inference only: make sure train-time layer behaviour is switched off.
    evaluator.eval()

    # --- Testing ----------------------------------------------------------
    views, _ = spatial_temporal_view_decomposition(
        input_video, sample_types, temporal_samplers
    )

    # Each branch is normalised with the statistics it was trained with;
    # views under any other key are passed through untouched.
    stats_by_view = {
        "technical": (mean, std),
        "aesthetic": (mean, std),
        "semantic": (mean_clip, std_clip),
    }
    for name in list(views):
        if name not in stats_by_view:
            continue
        channel_mean, channel_std = stats_by_view[name]
        num_clips = sample_types[name].get("num_clips", 1)
        views[name] = _normalize_view(
            views[name], num_clips, channel_mean, channel_std, device
        )

    # No gradients are needed for scoring; skipping them saves memory.
    with torch.no_grad():
        results = [r.mean().item() for r in evaluator(views)]
    return fuse_results(results)
# Define the input and output components for Gradio.
# The legacy `gr.inputs` / `gr.outputs` namespaces no longer exist in
# current Gradio releases, and "numpy" is not a video container format, so
# the top-level components are used instead. `gr.Video` hands the handler
# the path of the uploaded file.
video_input = gr.Video(label="Input Video")
output_label = gr.JSON(label="Scores")

# Create the Gradio interface
gradio_app = gr.Interface(
    fn=inference_one_video,
    inputs=video_input,
    outputs=output_label,
)

if __name__ == "__main__":
    gradio_app.launch()