SwinTExCo / app.py
duongttr's picture
Update change color space
19d9b4e
raw
history blame
1.72 kB
import gradio as gr
from src.inference import SwinTExCo
import cv2
import os
from PIL import Image
import time
import app_config as cfg
# Build the colorization model once at import time; every call to
# video_colorization reuses this single instance.
model = SwinTExCo(
    weights_path=cfg.ckpt_path,
)
def video_colorization(video_path, ref_image, progress=gr.Progress()):
    """Colorize a grayscale video using a color reference image.

    Frames are produced by ``model.predict_video`` and written, in order, to
    an MP4 file placed next to the input video.

    Args:
        video_path: Path of the uploaded input video (from the ``gr.Video``
            input component). May be ``None`` if nothing was uploaded.
        ref_image: Reference image from the ``gr.Image`` input component,
            converted with ``PIL.Image.fromarray``. May be ``None`` if
            nothing was uploaded.
        progress: Gradio progress tracker, injected by the interface.

    Returns:
        Path of the written ``<input stem>_colorized.mp4`` file.

    Raises:
        gr.Error: If an input is missing, the input video cannot be opened,
            or the output video cannot be created.
    """
    if video_path is None:
        raise gr.Error("Please upload an input video.")
    if ref_image is None:
        raise gr.Error("Please upload a reference image.")

    # splitext (unlike split('.')[0]) keeps stems that contain extra dots,
    # e.g. "clip.v2.mp4" -> "clip.v2_colorized.mp4".
    stem = os.path.splitext(os.path.basename(video_path))[0]
    output_path = os.path.join(os.path.dirname(video_path), stem + '_colorized.mp4')

    video_reader = cv2.VideoCapture(video_path)
    try:
        if not video_reader.isOpened():
            raise gr.Error("Could not open the input video.")
        fps = video_reader.get(cv2.CAP_PROP_FPS)
        height = int(video_reader.get(cv2.CAP_PROP_FRAME_HEIGHT))
        width = int(video_reader.get(cv2.CAP_PROP_FRAME_WIDTH))
        # The container's frame count is only an estimate (it may be wrong,
        # 0 or negative), so it drives the progress bar only and never bounds
        # how many colorized frames are written.
        num_frames = int(video_reader.get(cv2.CAP_PROP_FRAME_COUNT))

        ref_image = Image.fromarray(ref_image)

        video_writer = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
        try:
            if not video_writer.isOpened():
                raise gr.Error("Could not create the output video.")
            frames = model.predict_video(video_reader, ref_image)
            for index, colorized_frame in enumerate(frames, start=1):
                # Frames are treated as RGB; OpenCV's writer expects BGR.
                video_writer.write(cv2.cvtColor(colorized_frame, cv2.COLOR_RGB2BGR))
                # Drop the total once the estimate is exceeded (or unknown)
                # rather than report more than 100%.
                total = num_frames if index <= num_frames else None
                progress((index, total), desc="Colorizing video", unit="frames")
        finally:
            video_writer.release()
    finally:
        video_reader.release()
    return output_path
# UI components, created in the same order the interface lists them.
_input_video = gr.Video(format="mp4", sources="upload", label="Input video (grayscale)", interactive=True)
_reference_image = gr.Image(sources="upload", label="Reference image (color)")
_output_video = gr.Video(label="Output video (colorized)")

# Wire video_colorization to the UI, enable the request queue, and start
# the server immediately (this runs at import time).
app = gr.Interface(
    fn=video_colorization,
    inputs=[_input_video, _reference_image],
    outputs=_output_video,
    title=cfg.TITLE,
    description=cfg.DESCRIPTION,
).queue()
app.launch()