# Hugging Face Spaces demo — video question answering with CLIP + ruGPT-3.
import streamlit as st
import pandas as pd
import numpy as np
import torch
import clip
import tempfile
from tqdm import tqdm
from transformers import GPT2Tokenizer
from model import *
from inference import *
# Configure the browser tab title and favicon. Must be the first Streamlit
# call in the script. (The icon literal is mojibake of an emoji — original
# bytes preserved.)
st.set_page_config(
    page_title="Video Analysis AI",
    page_icon="๐ถ๏ธ",
)
@st.cache_resource
def load_model():
    """Load and cache all inference components (cached across Streamlit reruns).

    Returns:
        tuple: (model, clip_model, preprocess, tokenizer) where ``model`` is
        the project-local ClipCaptionModel, ``clip_model``/``preprocess`` come
        from OpenAI CLIP, and ``tokenizer`` is a ruGPT-3 GPT2Tokenizer.
    """
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    # jit=False loads the Python (non-TorchScript) CLIP model.
    clip_model, preprocess = clip.load("ViT-L/14@336px", device=device, jit=False)
    # NOTE(review): tokenizer is the *large* ruGPT-3 variant while the caption
    # model below wraps the *small* variant — confirm this mismatch is
    # intentional (the two variants may share one vocabulary).
    tokenizer = GPT2Tokenizer.from_pretrained('sberbank-ai/rugpt3large_based_on_gpt2')
    prefix_length = 50  # number of prefix tokens derived from the CLIP embedding
    model_path = 'transformer_clip_gpt-007.pt'
    # ClipCaptionModel is defined in the project-local `model` module (star-imported).
    model = ClipCaptionModel('sberbank-ai/rugpt3small_based_on_gpt2', prefix_length=prefix_length)
    # Weights are loaded on CPU first, then moved to the chosen device.
    model.load_state_dict(torch.load(model_path, map_location='cpu'))
    model.to(device)
    model.eval()
    return model, clip_model, preprocess, tokenizer
def _max_width_(max_width_px: int = 1400) -> None:
    """Widen Streamlit's main content block via injected CSS.

    Args:
        max_width_px: maximum width of the content container in pixels.
            Defaults to 1400, the value previously hard-coded.
    """
    # The original wrapped a constant in an f-string with no placeholders;
    # the width is now a real parameter interpolated into the CSS.
    st.markdown(
        f"""
        <style>
        .reportview-container .main .block-container{{
            max-width: {max_width_px}px;
        }}
        </style>
        """,
        unsafe_allow_html=True,
    )


_max_width_()
def main():
    """Streamlit entry point: upload a video, ask a question, display the answer."""
    model, clip_model, preprocess, tokenizer = load_model()
    prefix_length = 50  # must match the prefix length used in load_model()

    st.title("๐ฆพ Video Analysis for Education")
    st.header("")
    with st.sidebar.expander("โน๏ธ - About application", expanded=True):
        st.write(
            """
            - Upload the video
            - Make a question about the content of the video
            - Receive answer according to your question prompt
            """
        )

    uploaded_file = st.file_uploader("๐ Upload video: ", ['.mp4'])
    st.write("---")
    question = st.text_input("โ Enter question prompt: ", "")

    # Bug fix: the original called uploaded_file.read() unconditionally, so the
    # script crashed with AttributeError on every rerun before a file existed.
    if uploaded_file is None:
        st.info("Upload a video to get an answer.")
        return

    # Persist the upload to disk so read_video() can open it by path.
    # delete=False keeps the file alive after this handle is garbage-collected.
    tfile = tempfile.NamedTemporaryFile(delete=False)
    tfile.write(uploaded_file.read())

    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    val_embeddings = []
    val_captions = []
    result = ''
    text = f'Question: {question}? Answer:'

    # Sample 4 frames, tile them into a 2x2 grid, and encode with CLIP.
    # read_video / image_grid / get_ans come from the project-local `inference` module.
    video = read_video(tfile.name, transform=None, frames_num=4)
    if len(video) > 0:
        grid = image_grid(video, 2, 2)
        image = preprocess(grid).unsqueeze(0).to(device)
        with torch.no_grad():
            prefix = clip_model.encode_image(image).to(device, dtype=torch.float32)
        val_embeddings.append(prefix)
        val_captions.append(text)

    answers = []
    for emb, caption in tqdm(zip(val_embeddings, val_captions), total=len(val_embeddings)):
        ans = get_ans(model, tokenizer, emb, prefix_length, caption)
        answers.append(ans['answer'])

    # Bug fix: guard against an unreadable video (no frames -> no answers);
    # the original indexed answers[0] and could raise IndexError.
    if answers:
        # Keep only the text before the first ' A: ' continuation, if any.
        result = answers[0].split(' A: ')[0]

    st.text_input('โ Answer to the question', result, disabled=False)


if __name__ == '__main__':
    main()