Spaces:
Runtime error
Runtime error
import cv2 | |
import argparse | |
from basicsr.test_img import image_sr | |
from os import path as osp | |
import os | |
import shutil | |
from PIL import Image | |
import re | |
import imageio.v2 as imageio | |
import threading | |
from concurrent.futures import ThreadPoolExecutor | |
import time | |
def replace_filename(original_path, suffix): | |
directory = os.path.dirname(original_path) | |
old_filename = os.path.basename(original_path) | |
name_part, file_extension = os.path.splitext(old_filename) | |
new_filename = f"{name_part}{suffix}{file_extension}" | |
new_path = os.path.join(directory, new_filename) | |
return new_path | |
def create_temp_folder(folder_path): | |
if os.path.exists(folder_path): | |
shutil.rmtree(folder_path) | |
os.makedirs(folder_path) | |
def delete_temp_folder(folder_path): | |
shutil.rmtree(folder_path) | |
def extract_number(filename): | |
s = re.findall(r'\d+', filename) | |
return int(s[0]) if s else -1 | |
def bicubic_upsample_opencv(input_image_path, output_image_path, scale_factor): | |
img = cv2.imread(input_image_path) | |
original_height, original_width = img.shape[:2] | |
new_width = int(original_width * scale_factor) | |
new_height = int(original_height * scale_factor) | |
upsampled_img = cv2.resize(img, (new_width, new_height), interpolation=cv2.INTER_CUBIC) | |
cv2.imwrite(output_image_path, upsampled_img) | |
def process_frame(frame_count, frame, temp_LR_folder_path, temp_HR_folder_path, SR): | |
frame_path = os.path.join(temp_LR_folder_path, f"frame_{frame_count}{SR}.png") | |
cv2.imwrite(frame_path, frame) | |
HR_frame_path = os.path.join(temp_HR_folder_path, f"frame_{frame_count}.png") | |
if SR == 'x4': | |
bicubic_upsample_opencv(frame_path, HR_frame_path, 4) | |
elif SR == 'x2': | |
bicubic_upsample_opencv(frame_path, HR_frame_path, 2) | |
def video_sr(args): | |
file_name = os.path.basename(args.input_dir) | |
video_output_path = os.path.join(args.output_dir,file_name) | |
if args.SR == 'x4': | |
temp_LR_folder_path = os.path.join(args.output_dir, f'temp_LR/X4') | |
video_output_path = replace_filename(video_output_path, '_x4') | |
result_temp = osp.join(args.root_path, f'results/test_RGT_x4/visualization/Set5') | |
if args.SR == 'x2': | |
temp_LR_folder_path = os.path.join(args.output_dir, f'temp_LR/X2') | |
video_output_path = replace_filename(video_output_path, '_x2') | |
result_temp = osp.join(args.root_path, f'results/test_RGT_x2/visualization/Set5') | |
temp_HR_folder_path = os.path.join(args.output_dir, f'temp_HR') | |
# create_temp_folder(result_temp) | |
create_temp_folder(temp_LR_folder_path) | |
create_temp_folder(temp_HR_folder_path) | |
cap = cv2.VideoCapture(args.input_dir) | |
if not cap.isOpened(): | |
print("Error opening video file.") | |
return | |
t1 = time.time() | |
frame_count = 0 | |
frames_to_process = [] | |
while cap.isOpened(): | |
ret, frame = cap.read() | |
if not ret: | |
break | |
frames_to_process.append((frame_count, frame)) | |
frame_count += 1 | |
with ThreadPoolExecutor(max_workers = args.mul_numwork) as executor: | |
for frame_count, frame in frames_to_process: | |
executor.submit(process_frame, frame_count, frame, temp_LR_folder_path, temp_HR_folder_path, args.SR) | |
print("total frames:",frame_count) | |
print("fps :",cap.get(cv2.CAP_PROP_FPS)) | |
t2 = time.time() | |
print('mul threads: ',t2 - t1,'s') | |
# progress all frames in video | |
image_sr(args) | |
t3 = time.time() | |
print('image super resolution: ',t3 - t2,'s') | |
# recover video form all frames | |
frame_files = sorted(os.listdir(result_temp), key=extract_number) | |
video_frames = [imageio.imread(os.path.join(result_temp, frame_file)) for frame_file in frame_files] | |
fps = cap.get(cv2.CAP_PROP_FPS) | |
imageio.mimwrite(video_output_path, video_frames, fps=fps, quality=9) | |
t4 = time.time() | |
print('tranformer frames to video: ',t4 - t3,'s') | |
# release all resources | |
cap.release() | |
delete_temp_folder(os.path.dirname(temp_LR_folder_path)) | |
delete_temp_folder(temp_HR_folder_path) | |
delete_temp_folder(os.path.join(args.root_path, f'results')) | |
t5 = time.time() | |
print('delete time: ',t5 - t4,'s') | |
if __name__ == "__main__": | |
parser = argparse.ArgumentParser(description="RGT for Video Super-Resolution") | |
# make sure you SR is match with the ckpt_path | |
parser.add_argument("--SR", type=str, choices=['x2', 'x4'], default='x4', help='image resolution') | |
parser.add_argument("--ckpt_path", type=str, default = "/remote-home/lzy/RGT/experiments/pretrained_models/RGT_x4.pth") | |
parser.add_argument("--root_path", type=str, default = "/remote-home/lzy/RGT") | |
parser.add_argument("--input_dir", type=str, default= "/remote-home/lzy/RGT/datasets/video/video_test1.mp4") | |
parser.add_argument("--output_dir", type=str, default= "/remote-home/lzy/RGT/datasets/video_output") | |
parser.add_argument("--mul_numwork", type=int, default = 16, help ='max_workers to execute Multi') | |
parser.add_argument("--use_chop", type= bool, default = True, help ='use_chop: True # True to save memory, if img too large') | |
args = parser.parse_args() | |
video_sr(args) | |