Spaces:
Paused
Paused
File size: 4,331 Bytes
fc7c93b b44de48 fc7c93b 2429a9d fc7c93b 2429a9d fc7c93b 2429a9d fc7c93b 2429a9d fc7c93b e048ff5 2429a9d 97a5903 2429a9d b44de48 2429a9d fc7c93b 2429a9d fc7c93b 2429a9d fc7c93b 2429a9d fc7c93b 2429a9d e048ff5 2429a9d e048ff5 2429a9d fc7c93b e048ff5 afa1b5a fc7c93b e048ff5 fc7c93b e048ff5 2429a9d e048ff5 6c7dbc1 e048ff5 fc7c93b e048ff5 2429a9d fc7c93b e048ff5 2429a9d fc7c93b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 |
import datetime
import os
import shlex
import subprocess
import sys

import torch

import gradio as gr
def run_command(command):
    """Execute *command* in a shell and capture its output.

    Returns a ``(succeeded, text)`` pair: on success ``text`` is the
    captured stdout; on failure it is a message combining the exception,
    the child's stdout and its stderr.
    """
    print(f"Running command: {command}")
    try:
        completed = subprocess.run(
            command, shell=True, check=True, capture_output=True, text=True
        )
    except subprocess.CalledProcessError as exc:
        # Fold everything the child produced into one diagnostic string.
        message = (
            f"Error running command: {exc}"
            f"\nOutput: {exc.output}"
            f"\nError: {exc.stderr}"
        )
        return False, message
    return True, completed.stdout
def check_for_mp4_in_outputs(given_folder):
    """Return the path of an ``.mp4`` file inside *given_folder*, or ``None``.

    Returns ``None`` when the folder does not exist or contains no .mp4
    files. The directory listing is sorted so the same file is returned on
    every call — ``os.listdir`` order is arbitrary and filesystem-dependent,
    which previously made the picked video nondeterministic when several
    outputs were present.
    """
    if not os.path.exists(given_folder):
        return None
    # Sorted for a deterministic pick when more than one video exists.
    mp4_files = sorted(f for f in os.listdir(given_folder) if f.endswith('.mp4'))
    return os.path.join(given_folder, mp4_files[0]) if mp4_files else None
def infer(input_video, cropped_and_aligned):
    """Run KEEP face super-resolution on *input_video* via ``inference_keep.py``.

    Parameters
    ----------
    input_video : str
        Path to the uploaded video file (a Gradio temp path).
    cropped_and_aligned : bool
        True when the video already contains cropped & aligned faces
        (synthetic data): skips face detection/alignment and the
        RealESRGAN background upsampler.

    Returns
    -------
    tuple
        ``(mp4_path_or_None, status_message)`` matching the two Gradio
        output components (result video, status textbox).
    """
    try:
        if torch.cuda.is_available():
            torch.cuda.empty_cache()  # free VRAM left over from a previous run
        filepath = input_video
        timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
        output_folder_name = f"results_{timestamp}"
        # Quote the paths before shell interpolation: Gradio temp paths can
        # contain spaces/quotes, which previously broke the command line
        # (and was a shell-injection surface).
        quoted_in = shlex.quote(filepath)
        quoted_out = shlex.quote(output_folder_name)
        if cropped_and_aligned:
            command = (f"{sys.executable} inference_keep.py -i={quoted_in} "
                       f"-o={quoted_out} --has_aligned --save_video -s=1")
        else:
            command = (f"{sys.executable} inference_keep.py -i={quoted_in} "
                       f"-o={quoted_out} --draw_box --save_video -s=1 "
                       f"--bg_upsampler=realesrgan")
        success, output = run_command(command)
        if not success:
            return None, output  # surface the CLI error text in the UI
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        # inference_keep.py writes its results under <output>/<video stem>/.
        this_infer_folder = os.path.splitext(os.path.basename(filepath))[0]
        joined_path = os.path.join(output_folder_name, this_infer_folder)
        mp4_file_path = check_for_mp4_in_outputs(joined_path)
        if mp4_file_path:
            print(f"RESULT: {mp4_file_path}")
            return mp4_file_path, "Processing completed successfully."
        return None, "Processing completed, but no output video was found."
    except Exception as e:
        # Broad catch is deliberate: the UI should show a message, not crash.
        return None, f"An unexpected error occurred: {str(e)}"
# Gradio interface setup.
# The two output components are instantiated BEFORE the Blocks context so
# that both gr.Examples and the submit handler can reference them; they are
# attached to the layout later via .render().
result_video = gr.Video()
error_output = gr.Textbox(label="Status/Error")

with gr.Blocks() as demo:
    with gr.Column():
        # Title, subtitle and badge links to the project page / paper.
        gr.Markdown("# KEEP")
        gr.Markdown("## Kalman-Inspired Feature Propagation for Video Face Super-Resolution")
        gr.HTML("""
        <div style="display:flex;column-gap:4px;">
        <a href='https://jnjaby.github.io/projects/KEEP/'>
        <img src='https://img.shields.io/badge/Project-Page-Green'>
        </a>
        <a href='https://arxiv.org/abs/2408.05205'>
        <img src='https://img.shields.io/badge/Paper-Arxiv-red'>
        </a>
        </div>
        """)
        with gr.Row():
            # Left column: inputs and clickable examples.
            with gr.Column():
                input_video = gr.Video(label="Input Video")
                # Checked => video already has cropped & aligned faces, so
                # infer() skips detection/alignment and background upsampling.
                is_cropped_and_aligned = gr.Checkbox(label="Synthetic data", info="Is your input video ready with cropped and aligned faces ?", value=False)
                submit_btn = gr.Button("Submit")
                gr.Examples(
                    examples = [
                        ["./assets/examples/synthetic_1.mp4", True],
                        ["./assets/examples/synthetic_2.mp4", True],
                        ["./assets/examples/synthetic_3.mp4", True],
                        ["./assets/examples/synthetic_4.mp4", True],
                        ["./assets/examples/real_1.mp4", False],
                        ["./assets/examples/real_2.mp4", False],
                        ["./assets/examples/real_3.mp4", False],
                        ["./assets/examples/real_4.mp4", False]
                    ],
                    fn = infer,
                    inputs = [input_video, is_cropped_and_aligned],
                    outputs = [result_video, error_output],
                    run_on_click = False,
                    # "lazy" caches example outputs on first click rather
                    # than running all examples at startup.
                    cache_examples = "lazy"
                )
            # Right column: attach the pre-built output widgets here.
            with gr.Column():
                result_video.render()
                error_output.render()
    submit_btn.click(
        fn = infer,
        inputs = [input_video, is_cropped_and_aligned],
        outputs = [result_video, error_output],
        show_api=False
    )

# queue() serializes requests (GPU inference); errors surface in the UI.
demo.queue().launch(show_error=True, show_api=False)