import gradio as gr
import torch
from decord import VideoReader
import yaml
import matplotlib.pyplot as plt
from cover.datasets import UnifiedFrameSampler, spatial_temporal_view_decomposition
from cover.models import COVER
import pandas as pd
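
# Channel-wise normalization constants on a 0-255 scale: mean/std are applied to
# the technical and aesthetic views, mean_clip/std_clip to the CLIP-based semantic
# view (these match the standard ImageNet and CLIP preprocessing statistics).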
mean, std = (
    torch.FloatTensor([123.675, 116.28, 103.53]),
    torch.FloatTensor([58.395, 57.12, 57.375]),
)
mean_clip, std_clip = (
    torch.FloatTensor([122.77, 116.75, 104.09]),
    torch.FloatTensor([68.50, 66.63, 70.32]),
)
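
# sample_interval sets roughly how many source frames correspond to one sampled
# frame. normalization_array appears to keep reference [min, max] score ranges
# per branch (not used below); comparison_array is filled at inference time with
# per-video scores from the YouTube-UGC prediction CSVs.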
sample_interval = 30
normalization_array = {
    "semantic" : [-0.1477, -0.0181],
    "technical": [-1.8762, 1.2428],
    "aesthetic": [-1.2899, 0.5290],
    "overall"  : [-3.2538, 1.6728],
}
comparison_array = {
    "semantic" : [],  # example array, populated at inference time
    "technical": [],
    "aesthetic": [],
    "overall"  : [],
}
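
# Derive the sampler parameters from the video length: roughly one sampled frame
# per `sample_interval` source frames, with a minimum of one.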
def get_sampler_params(video_path):
    vr = VideoReader(video_path)
    total_frames = len(vr)
    clip_len = (total_frames + sample_interval // 2) // sample_interval
    if clip_len == 0:
        clip_len = 1
    t_frag = clip_len
    return total_frames, clip_len, t_frag
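
# Combine the per-branch outputs; the overall score is the sum of the semantic,
# technical and aesthetic scores.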
def fuse_results(results: list):
    x = results[0] + results[1] + results[2]
    return {
        "semantic" : results[0],
        "technical": results[1],
        "aesthetic": results[2],
        "overall"  : x,
    }
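
# Linearly rescale a raw score to [0, 5] given the min/max of a reference set.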
def normalize_score(score, min_score, max_score):
    return (score - min_score) / (max_score - min_score) * 5
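
# Describe how a score ranks within a reference list of YT-UGC scores.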
def compare_score(score, score_list):
    better_than = sum(1 for s in score_list if score > s)
    percentage = better_than / len(score_list) * 100
    if percentage > 50:
        return f"Better than {percentage:.0f}% of videos in YT-UGC"
    return f"Worse than {100 - percentage:.0f}% of videos in YT-UGC"
def create_bar_chart(scores, comparisons):
    labels = ['Semantic', 'Technical', 'Aesthetic', 'Overall']
    base_colors = ['#d62728', '#1f77b4', '#ff7f0e', '#bcbd22']
    fig, ax = plt.subplots(figsize=(8, 6))

    # Create vertical bars
    bars = ax.bar(labels, scores, color=base_colors, edgecolor='black', width=0.6)

    # Adding the text labels for scores
    for bar, score in zip(bars, scores):
        height = bar.get_height()
        ax.annotate(f'{score:.1f}',
                    xy=(bar.get_x() + bar.get_width() / 2, height),
                    xytext=(0, 3),  # 3 points vertical offset
                    textcoords="offset points",
                    ha='center', va='bottom',
                    color='black')

    # Add comparison text
    # for i, (bar, score) in enumerate(zip(bars, scores)):
    #     ax.annotate(comparisons[i],
    #                 xy=(bar.get_x() + bar.get_width(), bar.get_height() / 2),
    #                 xytext=(5, 0),  # 5 points horizontal offset
    #                 textcoords="offset points",
    #                 ha='left', va='center',
    #                 color=base_colors[i])

    ax.set_xlabel('Categories')
    ax.set_ylabel('Scores')
    ax.set_ylim(0, 5)
    ax.set_title('Video Quality Scores')

    plt.tight_layout()
    image_path = "./scores_bar_chart.png"
    plt.savefig(image_path)
    plt.close(fig)
    return image_path
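
# Full inference pipeline for a single uploaded video: build the frame samplers,
# run COVER, normalize the branch scores against the YouTube-UGC reference
# results, and return a bar-chart image for Gradio to display.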
def inference_one_video(input_video):
    """
    BASIC SETTINGS
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    with open("./cover.yml", "r") as f:
        opt = yaml.safe_load(f)
    dopt = opt["data"]["val-ytugc"]["args"]
    temporal_samplers = {}

    # Decide the sampler parameters automatically from the video length.
    total_frames, clip_len, t_frag = get_sampler_params(input_video)
    for stype, sopt in dopt["sample_types"].items():
        sopt["clip_len"] = clip_len
        sopt["t_frag"] = t_frag
        if stype in ('technical', 'aesthetic'):
            if total_frames > 1:
                sopt["clip_len"] = clip_len * 2
            if stype == 'technical':
                sopt["aligned"] = sopt["clip_len"]
        temporal_samplers[stype] = UnifiedFrameSampler(
            sopt["clip_len"] // sopt["t_frag"],
            sopt["t_frag"],
            sopt["frame_interval"],
            sopt["num_clips"],
        )

    """
    LOAD MODEL
    """
    evaluator = COVER(**opt["model"]["args"]).to(device)
    state_dict = torch.load(opt["test_load_path"], map_location=device)
    # strict=False avoids errors from the missing prompt_learner weights
    # in CLIP-IQA+ and the cross-gate module.
    evaluator.load_state_dict(state_dict['state_dict'], strict=False)

    """
    TESTING
    """
    views, _ = spatial_temporal_view_decomposition(
        input_video, dopt["sample_types"], temporal_samplers
    )

    for k, v in views.items():
        num_clips = dopt["sample_types"][k].get("num_clips", 1)
        if k in ('technical', 'aesthetic'):
            views[k] = (
                ((v.permute(1, 2, 3, 0) - mean) / std)
                .permute(3, 0, 1, 2)
                .reshape(v.shape[0], num_clips, -1, *v.shape[2:])
                .transpose(0, 1)
                .to(device)
            )
        elif k == 'semantic':
            views[k] = (
                ((v.permute(1, 2, 3, 0) - mean_clip) / std_clip)
                .permute(3, 0, 1, 2)
                .reshape(v.shape[0], num_clips, -1, *v.shape[2:])
                .transpose(0, 1)
                .to(device)
            )

    results = [r.mean().item() for r in evaluator(views)]
    pred_score = fuse_results(results)

    # Per-video scores for the YouTube-UGC set (from the bundled prediction CSVs),
    # used as the reference distribution for normalization and comparison.
    comparison_array["semantic"] = pd.read_csv('./prediction_results/youtube_ugc/smos.csv')['Mos']
    comparison_array["technical"] = pd.read_csv('./prediction_results/youtube_ugc/tmos.csv')['Mos']
    comparison_array["aesthetic"] = pd.read_csv('./prediction_results/youtube_ugc/amos.csv')['Mos']
    comparison_array["overall"] = pd.read_csv('./prediction_results/youtube_ugc/overall.csv')['Mos']

    normalized_scores = [
        normalize_score(pred_score["semantic"], comparison_array["semantic"].min(), comparison_array["semantic"].max()),
        normalize_score(pred_score["technical"], comparison_array["technical"].min(), comparison_array["technical"].max()),
        normalize_score(pred_score["aesthetic"], comparison_array["aesthetic"].min(), comparison_array["aesthetic"].max()),
        normalize_score(pred_score["overall"], comparison_array["overall"].min(), comparison_array["overall"].max()),
    ]

    comparisons = [
        compare_score(pred_score["semantic"], comparison_array["semantic"]),
        compare_score(pred_score["technical"], comparison_array["technical"]),
        compare_score(pred_score["aesthetic"], comparison_array["aesthetic"]),
        compare_score(pred_score["overall"], comparison_array["overall"]),
    ]

    image_path = create_bar_chart(normalized_scores, comparisons)
    return image_path

# Define the Gradio input and output components.
video_input = gr.Video(label="Input Video")
output_image = gr.Image(label="Scores")

# Create the Gradio interface.
gradio_app = gr.Interface(fn=inference_one_video, inputs=video_input, outputs=output_image)

if __name__ == "__main__":
    gradio_app.launch()
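
# To try the demo locally (assuming this file is saved as app.py, with cover.yml,
# the checkpoint referenced by test_load_path, and the prediction_results/youtube_ugc
# CSVs available alongside it):
#   python app.py
# Gradio prints a local URL (http://127.0.0.1:7860 by default) once the server starts.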