# vqa-analysis / app.py
import streamlit as st
import pandas as pd
import numpy as np
import torch
import clip
import tempfile
from tqdm import tqdm
from transformers import GPT2Tokenizer
from model import *        # repo modules; together these supply
from inference import *    # ClipCaptionModel, read_video, image_grid, get_ans
st.set_page_config(
    page_title="Video Analysis AI",
    page_icon="🕶️",
)
@st.cache_resource
def load_model():
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    # CLIP visual encoder: ViT-L/14 at 336px input resolution.
    clip_model, preprocess = clip.load("ViT-L/14@336px", device=device, jit=False)
    # NB: the tokenizer comes from the *large* rugpt3 checkpoint while the
    # caption model wraps the *small* one; this only works if both
    # checkpoints share the same BPE vocabulary.
    tokenizer = GPT2Tokenizer.from_pretrained('sberbank-ai/rugpt3large_based_on_gpt2')
    prefix_length = 50
    model_path = 'transformer_clip_gpt-007.pt'
    model = ClipCaptionModel('sberbank-ai/rugpt3small_based_on_gpt2', prefix_length=prefix_length)
    model.load_state_dict(torch.load(model_path, map_location='cpu'))
    model.to(device)
    model.eval()
    return model, clip_model, preprocess, tokenizer
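# --- Illustrative sketch (not used by the app) ----------------------------
# ClipCaptionModel is wildcard-imported from the repo's own model.py, so its
# exact definition is not visible here. The sketch below shows the usual
# ClipCap-style layout such models follow (an assumption, not this repo's
# verified code): a small MLP projects the CLIP image embedding into
# `prefix_length` pseudo-token embeddings prepended to GPT-2's input.
# All names below are hypothetical.
class _ClipCaptionSketch(torch.nn.Module):
    def __init__(self, gpt2_name, prefix_length, clip_dim=768):
        super().__init__()
        from transformers import GPT2LMHeadModel
        self.gpt = GPT2LMHeadModel.from_pretrained(gpt2_name)
        self.prefix_length = prefix_length
        gpt_dim = self.gpt.transformer.wte.weight.shape[1]
        # Project one CLIP vector into prefix_length GPT-2 token embeddings.
        self.clip_project = torch.nn.Sequential(
            torch.nn.Linear(clip_dim, (clip_dim + gpt_dim * prefix_length) // 2),
            torch.nn.Tanh(),
            torch.nn.Linear((clip_dim + gpt_dim * prefix_length) // 2, gpt_dim * prefix_length),
        )

    def forward(self, clip_embedding, token_ids):
        gpt_dim = self.gpt.transformer.wte.weight.shape[1]
        prefix = self.clip_project(clip_embedding).view(-1, self.prefix_length, gpt_dim)
        token_embeds = self.gpt.transformer.wte(token_ids)
        # GPT-2 attends over [image prefix; question tokens] jointly.
        return self.gpt(inputs_embeds=torch.cat((prefix, token_embeds), dim=1))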
def _max_width_():
    # Widen Streamlit's default content column via injected CSS.
    max_width_str = "max-width: 1400px;"
    st.markdown(
        f"""
        <style>
        .reportview-container .main .block-container{{
            {max_width_str}
        }}
        </style>
        """,
        unsafe_allow_html=True,
    )
_max_width_()
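# --- Illustrative sketch (not used by the app) ----------------------------
# read_video() and image_grid() are wildcard-imported from the repo's own
# modules above, so their real code is not shown here. A minimal sketch of
# what they plausibly do, assuming OpenCV frame decoding and PIL grid
# assembly (both names and internals are hypothetical):
def _read_video_sketch(path, transform=None, frames_num=4):
    """Sample `frames_num` evenly spaced frames from the video as PIL images."""
    import cv2
    from PIL import Image
    cap = cv2.VideoCapture(path)
    total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    frames = []
    for idx in np.linspace(0, max(total - 1, 0), frames_num, dtype=int):
        cap.set(cv2.CAP_PROP_POS_FRAMES, int(idx))
        ok, frame = cap.read()
        if not ok:
            break
        # OpenCV decodes BGR; convert to RGB for PIL/CLIP preprocessing.
        frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
    cap.release()
    return frames

def _image_grid_sketch(imgs, rows, cols):
    """Paste rows*cols PIL images into one grid image (a single CLIP input)."""
    from PIL import Image
    w, h = imgs[0].size
    grid = Image.new('RGB', (cols * w, rows * h))
    for i, img in enumerate(imgs[:rows * cols]):
        grid.paste(img, ((i % cols) * w, (i // cols) * h))
    return grid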
def main():
    model, clip_model, preprocess, tokenizer = load_model()
    prefix_length = 50

    st.title("🦾 Video Analysis for Education")
    st.header("")

    with st.sidebar.expander("ℹ️ - About application", expanded=True):
        st.write(
            """
            - Upload a video
            - Ask a question about the content of the video
            - Receive an answer based on your question prompt
            """
        )

    uploaded_file = st.file_uploader("📌 Upload video: ", type=['mp4'])
    if uploaded_file is not None:
        st.write('Video uploaded successfully.')
    else:
        st.write('No video uploaded yet.')
    # if play_video:
    #     video_bytes = uploaded_file.read()
    #     st.video(video_bytes)

    st.write("---")
    a, b = st.columns([4, 1])
    question = a.text_input(
        label="❔ Enter question prompt: ",
        placeholder="",
        # label_visibility="collapsed",
    )
    button = b.button("Send", use_container_width=True)
    if button:
        if uploaded_file is None:
            st.warning('Please upload a video first.')
            return
        # Persist the upload to disk so the video decoder can read it by path.
        tfile = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4')
        tfile.write(uploaded_file.read())
        tfile.flush()

        device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
        val_embeddings = []
        val_captions = []
        result = ''
        text = f'Question: {question}? Answer:'

        # Read the video, build a 2x2 frame grid, and encode it with CLIP.
        video = read_video(tfile.name, transform=None, frames_num=4)
        if len(video) > 0:
            grid = image_grid(video, 2, 2)
            image = preprocess(grid).unsqueeze(0).to(device)
            with torch.no_grad():
                prefix = clip_model.encode_image(image).to(device, dtype=torch.float32)
            val_embeddings.append(prefix)
            val_captions.append(text)

        answers = []
        for i in tqdm(range(len(val_embeddings))):
            emb = val_embeddings[i]
            caption = val_captions[i]
            ans = get_ans(model, tokenizer, emb, prefix_length, caption)
            answers.append(ans['answer'])

        if answers:
            # Keep only the text before any follow-up ' A: ' continuation.
            result = answers[0].split(' A: ')[0]
            st.text_input('✅ Answer to the question', result, disabled=False)
        else:
            st.warning('Could not read any frames from the uploaded video.')
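# --- Illustrative sketch (not used by the app) ----------------------------
# get_ans() is imported from inference.py and its internals are not shown
# here. One plausible implementation, assuming the ClipCap-style attribute
# names sketched above (model.clip_project, model.gpt), is greedy decoding
# conditioned on the projected CLIP prefix plus the tokenized question.
# Everything below is an assumption, not the repo's verified code.
def _get_ans_sketch(model, tokenizer, prefix_embed, prefix_length, prompt,
                    max_new_tokens=32):
    """Greedy decoding: CLIP prefix + tokenized prompt -> answer string."""
    device = prefix_embed.device
    tokens = torch.tensor(tokenizer.encode(prompt), device=device).unsqueeze(0)
    # Project the CLIP vector into prefix_length GPT-2 pseudo-token embeddings.
    prefix = model.clip_project(prefix_embed).view(1, prefix_length, -1)
    generated = tokens
    with torch.no_grad():
        for _ in range(max_new_tokens):
            embeds = torch.cat((prefix, model.gpt.transformer.wte(generated)), dim=1)
            logits = model.gpt(inputs_embeds=embeds).logits
            next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
            generated = torch.cat((generated, next_token), dim=1)
            if next_token.item() == tokenizer.eos_token_id:
                break
    # Return only the newly generated tokens, not the prompt.
    return {'answer': tokenizer.decode(generated[0, tokens.shape[1]:])}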
if __name__ == '__main__':
    main()