File size: 2,364 Bytes
c6f92cc
 
 
afc90f1
c6f92cc
 
84598ac
c6f92cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84598ac
1c3e1c7
 
c6f92cc
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
### app.py
# User interface for the demo.
###

import os
import pandas as pd
import gradio as gr
from gradio_rich_textbox import RichTextbox

from demo import VideoCLSModel


def load_samples(data_root):
    """Collect the video path of every sample stored under ``data_root``.

    Samples are described by consecutively numbered annotation files
    ``{data_root}/csv/0.csv``, ``{data_root}/csv/1.csv``, ...  The ``id``
    value in the first row of each file names the matching clip
    ``{data_root}/video/{id}.mp4``.

    Args:
        data_root: Dataset directory holding the ``csv`` and ``video``
            sub-directories.

    Returns:
        A list of video paths, ordered by sample index.
    """
    # Count only the annotation files: a stray directory entry (editor
    # backup, ``.DS_Store``, ...) would otherwise inflate the count and make
    # the loop below request a CSV that does not exist.
    n_sample = sum(
        1 for name in os.listdir(f'{data_root}/csv') if name.endswith('.csv')
    )

    sample_videos = []
    for i in range(n_sample):
        # Read ids as text so an all-digit id such as "00123" keeps its
        # leading zeros and still matches the video file name.
        df = pd.read_csv(f'{data_root}/csv/{i}.csv', dtype={'id': str})
        vid = df['id'].values[0]
        sample_videos.append(f'{data_root}/video/{vid}.mp4')

    return sample_videos

def format_pred(pred, gt):
    """Render predictions as rich text, coloured by correctness.

    Every entry of ``pred`` that is contained in ``gt`` is wrapped in a green
    colour tag and every other entry in a red one.  The tagged entries are
    joined with ``', '`` in their original order.
    """
    hit_markup = '[color=green]{}[/color]'
    miss_markup = '[color=red]{}[/color]'

    return ', '.join(
        (hit_markup if item in gt else miss_markup).format(item)
        for item in pred
    )

def main():
    """Build and launch the Gradio demo comparing LaViLa and Ego-VPA.

    Loads both classifiers from their config files and lists the sample
    videos under ``data/charades_ego``.  The served page lets the user pick
    an example, which fills the video player and a hidden sample index; the
    "Predict" button passes that index to both models and shows their
    colour-coded predictions.
    """
    lavila = VideoCLSModel("configs/charades_ego/zeroshot.yml")
    egovpa = VideoCLSModel("configs/charades_ego/egovpa.yml")
    sample_videos = load_samples('data/charades_ego')
    print(sample_videos)

    def predict(idx):
        """Run both models on sample ``idx`` and colour their predictions.

        Returns the ground truth followed by the formatted LaViLa and
        Ego-VPA predictions, matching the order of the click outputs.
        """
        # The hidden index is presumably still empty (None) when the button
        # is pressed before any example was chosen; report that clearly
        # instead of handing None to the models.
        if idx is None:
            raise gr.Error("Please choose a sample video before clicking Predict.")
        # No precision is set on the Number component, so the value may
        # arrive as a float; the models are addressed by integer sample index.
        idx = int(idx)

        zeroshot_action, _ = lavila.predict(idx)
        # The ground truth returned by the second model is the one used for
        # colouring and display.
        egovpa_action, gt_action = egovpa.predict(idx)
        zeroshot_action = format_pred(zeroshot_action, gt_action)
        egovpa_action = format_pred(egovpa_action, gt_action)

        return gt_action, zeroshot_action, egovpa_action

    with gr.Blocks() as demo:
        gr.Markdown(
            """
            # Ego-VPA Demo
            Choose a sample video and click predict to view the results
            (<span style="color:green">correct</span>/<span style="color:red">incorrect</span>).
            """
        )

        with gr.Row():
            with gr.Column():
                video = gr.PlayableVideo(label="video", height='300px', interactive=False, autoplay=True)
            with gr.Column():
                # Hidden components: idx carries the chosen sample index to
                # predict(); label receives the ground truth without showing it.
                idx = gr.Number(label="Idx", visible=False)
                label = RichTextbox(label="Ground Truth", visible=False)
                zeroshot = RichTextbox(label="LaViLa (zero-shot) prediction")
                ours = RichTextbox(label="Ego-VPA prediction")
        btn = gr.Button("Predict", variant="primary")
        btn.click(predict, inputs=[idx], outputs=[label, zeroshot, ours])
        # Each example row pairs a sample index with its video path.
        gr.Examples(examples=[[i, x] for i, x in enumerate(sample_videos)], inputs=[idx, video])

    demo.launch(share=True)


# Launch the demo only when this file is executed as a script.
if __name__ == "__main__":
    main()