File size: 4,331 Bytes
fc7c93b
b44de48
fc7c93b
 
 
 
 
 
2429a9d
 
fc7c93b
2429a9d
 
fc7c93b
2429a9d
fc7c93b
 
 
 
 
 
2429a9d
fc7c93b
e048ff5
2429a9d
 
97a5903
2429a9d
 
 
b44de48
2429a9d
 
 
 
fc7c93b
2429a9d
 
 
fc7c93b
2429a9d
fc7c93b
2429a9d
 
 
fc7c93b
2429a9d
 
 
 
 
e048ff5
2429a9d
 
 
 
e048ff5
2429a9d
 
fc7c93b
 
 
e048ff5
afa1b5a
 
 
 
 
 
 
 
 
 
fc7c93b
 
e048ff5
 
fc7c93b
e048ff5
 
 
 
 
 
 
 
 
 
 
 
 
2429a9d
e048ff5
6c7dbc1
e048ff5
 
fc7c93b
e048ff5
2429a9d
fc7c93b
 
 
e048ff5
2429a9d
fc7c93b
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import datetime
import os
import shlex
import subprocess
import sys

import gradio as gr
import torch

def run_command(command):
    """Run a command and return its success flag and output.

    Args:
        command: Either a single string (or bytes), which is executed through
            the system shell (``shell=True``) exactly as before, or a sequence
            of arguments, which is executed directly without a shell so that
            arguments containing spaces or shell metacharacters are passed
            through verbatim.

    Returns:
        A ``(success, text)`` tuple. On success ``text`` is the command's
        captured stdout; on failure it is a human-readable error message that
        includes the captured stdout/stderr when available.
    """
    print(f"Running command: {command}")
    # Only plain command strings go through the shell; argument lists are
    # executed directly, which avoids any shell quoting/injection issues.
    use_shell = isinstance(command, (str, bytes))
    try:
        result = subprocess.run(
            command, shell=use_shell, check=True, capture_output=True, text=True
        )
    except subprocess.CalledProcessError as e:
        return False, f"Error running command: {e}\nOutput: {e.output}\nError: {e.stderr}"
    except OSError as e:
        # Raised e.g. when the program of an argument list cannot be found;
        # with shell=True the shell reports that as a non-zero exit instead.
        return False, f"Error running command: {e}"
    return True, result.stdout

def check_for_mp4_in_outputs(given_folder):
    """Return the path of an ``.mp4`` file directly inside *given_folder*.

    Only the folder itself is searched, not its sub-folders. When several
    ``.mp4`` files are present, the alphabetically first one is returned so the
    result does not depend on the arbitrary ordering of ``os.listdir``.

    Args:
        given_folder: Directory to search.

    Returns:
        ``os.path.join(given_folder, name)`` for the chosen file, or ``None``
        if *given_folder* does not exist, is not a directory, or contains no
        ``.mp4`` file.
    """
    # isdir (rather than exists) so that a path to a regular file yields None
    # instead of os.listdir raising NotADirectoryError.
    if not os.path.isdir(given_folder):
        return None
    mp4_files = sorted(name for name in os.listdir(given_folder) if name.endswith('.mp4'))
    return os.path.join(given_folder, mp4_files[0]) if mp4_files else None

def infer(input_video, cropped_and_aligned):
    """Run ``inference_keep.py`` on an uploaded video and locate its output.

    The inference script is launched in a separate Python process via
    ``run_command``; afterwards the produced ``.mp4`` is looked up under
    ``results_<timestamp>/<input file stem>/``.

    Args:
        input_video: Path of the input video as provided by the Gradio
            ``Video`` component; may be empty/``None`` if nothing was uploaded.
        cropped_and_aligned: If true, the script is run with ``--has_aligned``;
            otherwise with ``--draw_box`` and ``--bg_upsampler=realesrgan``.

    Returns:
        A ``(video_path, status)`` tuple. ``video_path`` is the produced
        ``.mp4`` path or ``None``; ``status`` is a message for the UI. Errors
        are reported through ``status`` rather than raised.
    """
    if not input_video:
        # Submitting without an upload gives an empty value; fail fast rather
        # than launching inference with "-i=None".
        return None, "Please upload an input video first."

    try:
        torch.cuda.empty_cache()

        filepath = str(input_video)
        # Microsecond resolution keeps the output folder unique even when two
        # requests start within the same second.
        timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S%f")
        output_folder_name = f"results_{timestamp}"

        if cropped_and_aligned:
            mode_args = "--has_aligned --save_video -s=1"
        else:
            mode_args = "--draw_box --save_video -s=1 --bg_upsampler=realesrgan"

        # run_command executes a string through the shell, so quote every
        # interpolated path: the uploaded file name (and the interpreter path)
        # may contain spaces or shell metacharacters.
        command = (
            f"{shlex.quote(sys.executable)} inference_keep.py "
            f"-i={shlex.quote(filepath)} -o={shlex.quote(output_folder_name)} {mode_args}"
        )

        success, output = run_command(command)
        if not success:
            return None, output  # No video; surface the error message instead.

        torch.cuda.empty_cache()

        # Expected result location: <output folder>/<input file stem>/*.mp4
        this_infer_folder = os.path.splitext(os.path.basename(filepath))[0]
        joined_path = os.path.join(output_folder_name, this_infer_folder)
        mp4_file_path = check_for_mp4_in_outputs(joined_path)

        if mp4_file_path:
            print(f"RESULT: {mp4_file_path}")
            return mp4_file_path, "Processing completed successfully."
        return None, "Processing completed, but no output video was found."

    except Exception as e:
        # UI boundary: report any failure in the status box instead of raising.
        return None, f"An unexpected error occurred: {str(e)}"

# Gradio interface setup
# The output components are created up front, outside the Blocks context, so
# they can be passed as `outputs` to gr.Examples below before they are placed
# in the layout; they are inserted into the right-hand column via .render().
result_video = gr.Video()
error_output = gr.Textbox(label="Status/Error")

with gr.Blocks() as demo:
    with gr.Column():
        # Title, subtitle and badge links to the project page and the paper.
        gr.Markdown("# KEEP")
        gr.Markdown("## Kalman-Inspired Feature Propagation for Video Face Super-Resolution")
        gr.HTML("""
        <div style="display:flex;column-gap:4px;">
            <a href='https://jnjaby.github.io/projects/KEEP/'>
                <img src='https://img.shields.io/badge/Project-Page-Green'>
            </a> 
            <a href='https://arxiv.org/abs/2408.05205'>
                <img src='https://img.shields.io/badge/Paper-Arxiv-red'>
            </a>
        </div>
        """)
        with gr.Row():
            # Left column: inputs, submit button and example videos.
            with gr.Column():
                input_video = gr.Video(label="Input Video")
                # Passed to infer() as `cropped_and_aligned`: checked selects the
                # --has_aligned run, unchecked the --draw_box/realesrgan run.
                is_cropped_and_aligned = gr.Checkbox(label="Synthetic data", info="Is your input video ready with cropped and aligned faces ?", value=False)
                submit_btn = gr.Button("Submit")
                # Each example row is [video path, is_cropped_and_aligned]:
                # the synthetic_* clips are flagged as already cropped/aligned,
                # the real_* clips are not.
                # NOTE(review): with cache_examples="lazy", Gradio presumably
                # runs `infer` for an example the first time it is selected and
                # reuses the stored result afterwards — confirm in Gradio docs.
                gr.Examples(
                    examples = [
                        ["./assets/examples/synthetic_1.mp4", True],
                        ["./assets/examples/synthetic_2.mp4", True],
                        ["./assets/examples/synthetic_3.mp4", True],
                        ["./assets/examples/synthetic_4.mp4", True],
                        ["./assets/examples/real_1.mp4", False],
                        ["./assets/examples/real_2.mp4", False],
                        ["./assets/examples/real_3.mp4", False],
                        ["./assets/examples/real_4.mp4", False]
                    ],
                    fn = infer,
                    inputs = [input_video, is_cropped_and_aligned],
                    outputs = [result_video, error_output],
                    run_on_click = False,
                    cache_examples = "lazy"
                )
            
            # Right column: place the pre-created output components here.
            with gr.Column():
                result_video.render()
                error_output.render()

    # infer() returns (video_path_or_None, status_message), matching the two
    # output components in order.
    submit_btn.click(
        fn = infer, 
        inputs = [input_video, is_cropped_and_aligned],
        outputs = [result_video, error_output],
        show_api=False
    )

# Enable the request queue and start the app server (runs at import time).
demo.queue().launch(show_error=True, show_api=False)