File size: 3,509 Bytes
0d89394
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d3af41f
 
 
 
 
0d89394
 
 
 
 
 
9f8e9e6
 
 
 
 
 
 
0d89394
9f8e9e6
bf0aeaf
9f8e9e6
 
0d89394
9f8e9e6
 
 
 
 
0d89394
9f8e9e6
 
0d89394
9f8e9e6
 
 
0d89394
9f8e9e6
 
0d89394
9f8e9e6
 
0d89394
9f8e9e6
0d89394
9f8e9e6
 
 
0d89394
9f8e9e6
 
0d89394
9f8e9e6
 
 
0d89394
bf0aeaf
 
0d89394
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import os
import tempfile

import clip
import numpy as np
import pandas as pd
import streamlit as st
import torch
from tqdm import tqdm
from transformers import GPT2Tokenizer

from model import *
from inference import *

# Page metadata. Streamlit expects set_page_config to run before other st.*
# commands (strictly required in older releases), so keep it first.
st.set_page_config(
    page_title="Video Analysis AI",
    page_icon="🕶️",
)

@st.cache_resource
def load_model():
    """Build the inference stack once and cache it across Streamlit reruns.

    Returns:
        A tuple ``(model, clip_model, preprocess, tokenizer)``:
        ``model`` is a ``ClipCaptionModel`` (ruGPT-3 small backbone, prefix
        length 50) restored from ``transformer_clip_gpt-007.pt``, moved to the
        active device and switched to eval mode; ``clip_model`` and
        ``preprocess`` come from ``clip.load("ViT-L/14@336px")``; ``tokenizer``
        is a GPT-2 tokenizer.

    NOTE(review): the tokenizer is taken from the *large* ruGPT-3 checkpoint
    while the caption model wraps the *small* one; presumably both share the
    same vocabulary — confirm.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    clip_model, preprocess = clip.load("ViT-L/14@336px", device=device, jit=False)
    tokenizer = GPT2Tokenizer.from_pretrained('sberbank-ai/rugpt3large_based_on_gpt2')

    checkpoint = 'transformer_clip_gpt-007.pt'
    caption_model = ClipCaptionModel('sberbank-ai/rugpt3small_based_on_gpt2', prefix_length=50)

    # Map the stored tensors to CPU so the checkpoint loads even without CUDA;
    # the module is moved to the selected device right after.
    state = torch.load(checkpoint, map_location='cpu')
    caption_model.load_state_dict(state)
    caption_model.to(device)
    caption_model.eval()

    return caption_model, clip_model, preprocess, tokenizer

def _max_width_(max_width_px: int = 1400):
    """Inject CSS that caps the width of the main content container.

    Args:
        max_width_px: Maximum width of the main block container in pixels.
            Defaults to 1400, the value previously hard-coded.

    NOTE(review): the ``.reportview-container .main .block-container``
    selector targets the DOM of older Streamlit releases and may match
    nothing on newer ones; ``st.set_page_config(layout="wide")`` is the
    supported way to widen the page — confirm against the installed version.
    """
    # Coerce to int so only a plain number can ever land inside the <style> tag.
    max_width_px = int(max_width_px)
    st.markdown(
        f"""
    <style>
    .reportview-container .main .block-container{{
        max-width: {max_width_px}px;
    }}
    </style>
    """,
        unsafe_allow_html=True,
    )


_max_width_()


def _save_upload_to_tempfile(uploaded_file):
    """Write an uploaded file's bytes to a closed temporary file and return its path.

    The file is created with ``delete=False`` and closed before returning, so
    its contents are flushed to disk and it can be reopened by name (which is
    not possible on Windows while the handle is still open). The caller must
    remove the file when done.
    """
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
        # getvalue() returns the whole upload regardless of the current read
        # position, unlike read(), which returns b"" once consumed.
        tmp.write(uploaded_file.getvalue())
    return tmp.name


def _answer_question(model, clip_model, preprocess, tokenizer, video_path, question, prefix_length):
    """Answer *question* about the video stored at *video_path*.

    Samples 4 frames with ``read_video``, tiles them into a 2x2 grid with
    ``image_grid``, encodes the grid with CLIP and passes the embedding plus a
    ``"Question: ...? Answer:"`` prompt to ``get_ans``.

    Returns:
        The answer text truncated before the first ``' A: '``, or ``None`` if
        ``read_video`` produced no frames.
    """
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    # NOTE(review): the template appends '?', so a question typed with a
    # trailing '?' reaches the model as '??' — confirm the intended prompt.
    prompt = f'Question: {question}? Answer:'

    frames = read_video(video_path, transform=None, frames_num=4)
    if len(frames) == 0:
        return None

    grid = image_grid(frames, 2, 2)
    image = preprocess(grid).unsqueeze(0).to(device)
    with torch.no_grad():
        prefix = clip_model.encode_image(image).to(device, dtype=torch.float32)

    ans = get_ans(model, tokenizer, prefix, prefix_length, prompt)
    # Keep only the text before any further ' A: ' the generator emitted.
    return ans['answer'].split(' A: ')[0]


def main():
    """Render the video question-answering page and answer a question on "Send"."""
    model, clip_model, preprocess, tokenizer = load_model()
    # Must match the prefix_length the caption model is built with in load_model().
    prefix_length = 50

    st.title("🦾 Video Analysis for Education")
    st.header("")

    with st.sidebar.expander("ℹ️ - About application", expanded=True):
        st.write(
            """
            - Upload the video
            - Make a question about the content of the video
            - Receive an answer to your question prompt
            """
        )

    uploaded_file = st.file_uploader("📌 Upload video: ", ['.mp4'])

    if uploaded_file is not None:
        st.write('success')
    else:
        st.write('no')

    st.write("---")

    question_col, send_col = st.columns([4, 1])
    question = question_col.text_input(
        label="❔ Enter question prompt: ",
        placeholder="",
    )
    button = send_col.button("Send", use_container_width=True)

    if not button:
        return
    # Guard against pressing "Send" before the inputs exist; previously this
    # crashed on uploaded_file.read() when no file was uploaded.
    if uploaded_file is None:
        st.warning("Please upload a video before sending a question.")
        return
    if not question.strip():
        st.warning("Please enter a question.")
        return

    video_path = _save_upload_to_tempfile(uploaded_file)
    try:
        result = _answer_question(
            model, clip_model, preprocess, tokenizer, video_path, question, prefix_length
        )
    finally:
        # Best-effort cleanup: the temp file was created with delete=False.
        try:
            os.remove(video_path)
        except OSError:
            pass

    if result is None:
        st.error("Could not read any frames from the uploaded video.")
        return

    st.text_input('✅ Answer to the question', result, disabled=False)


if __name__ == '__main__':
    main()