Spaces:
Sleeping
Sleeping
import streamlit as st | |
import pandas as pd | |
import numpy as np | |
import torch | |
import clip | |
import tempfile | |
from tqdm import tqdm | |
from transformers import GPT2Tokenizer | |
from model import * | |
from inference import * | |
# Page-level settings: browser-tab title and favicon.
st.set_page_config(
    page_title="Video Analysis AI",
    # Restored from mojibake: the UTF-8 bytes of the sunglasses emoji had been
    # decoded with a single-byte codepage, leaving a string that is not an emoji.
    page_icon="🕶️",
)
def _build_models():
    """Construct the caption model, CLIP encoder, CLIP preprocessing and tokenizer.

    This is the expensive, uncached loader; call ``load_model`` instead.

    Returns:
        A ``(model, clip_model, preprocess, tokenizer)`` tuple. ``model`` is the
        ``ClipCaptionModel`` with the checkpoint weights loaded, moved to the
        selected device and switched to eval mode.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    clip_model, preprocess = clip.load("ViT-L/14@336px", device=device, jit=False)

    # NOTE(review): the tokenizer is taken from the *large* ruGPT-3 checkpoint
    # while the caption model below is built on the *small* one -- confirm the
    # two share a vocabulary.
    tokenizer = GPT2Tokenizer.from_pretrained('sberbank-ai/rugpt3large_based_on_gpt2')

    prefix_length = 50
    model_path = 'transformer_clip_gpt-007.pt'
    model = ClipCaptionModel('sberbank-ai/rugpt3small_based_on_gpt2', prefix_length=prefix_length)
    # Deserialize onto the CPU first so the checkpoint also loads on hosts
    # without a GPU, then move the whole module to the target device.
    model.load_state_dict(torch.load(model_path, map_location='cpu'))
    model.to(device)
    model.eval()
    return model, clip_model, preprocess, tokenizer


def load_model():
    """Return ``(model, clip_model, preprocess, tokenizer)``, loading them once.

    Streamlit re-executes the script on every widget interaction, so an
    uncached loader would re-read all weights on each rerun. The loader is
    therefore wrapped in Streamlit's resource cache when the installed version
    provides one (``st.cache_resource``, or ``st.experimental_singleton`` on
    older releases); otherwise it falls back to loading directly.

    The cached objects are shared between reruns and sessions, so callers
    must treat them as read-only.
    """
    # The cache decorator is applied lazily here rather than at definition time
    # so that importing this module has no Streamlit side effects.
    cache = getattr(st, "cache_resource", None) or getattr(st, "experimental_singleton", None)
    loader = _build_models if cache is None else cache(_build_models)
    return loader()
def _max_width_():
    """Inject a ``<style>`` element that caps the main block container's width.

    The CSS is emitted through ``st.markdown`` with raw HTML enabled.
    """
    # NOTE(review): ``.reportview-container`` looks like a legacy Streamlit
    # class name -- confirm the selector still matches on the deployed version.
    width_rule = "max-width: 1400px;"
    css = (
        "\n<style>\n"
        ".reportview-container .main .block-container{\n"
        f"{width_rule}\n"
        "}\n"
        "</style>\n"
    )
    st.markdown(css, unsafe_allow_html=True)


# Apply the width override as soon as the script runs.
_max_width_()
# Sidebar help text, rendered as a Markdown bullet list.
_ABOUT_TEXT = (
    "\n"
    "- Upload the video\n"
    "- Make a question about the content of the video\n"
    "- Recieve answer according your question prompt\n"
)


def _answer_question(uploaded_file, question, model, clip_model, preprocess,
                     tokenizer, prefix_length):
    """Answer ``question`` about the uploaded video.

    Args:
        uploaded_file: The object returned by ``st.file_uploader`` (not None).
        question: The user's question text; it is embedded in the prompt as-is.
        model, clip_model, preprocess, tokenizer: Objects from ``load_model``.
        prefix_length: Prefix length forwarded to ``get_ans``.

    Returns:
        The answer text (everything before the first ``' A: '`` marker in the
        generated answer), or ``None`` when ``read_video`` yields no frames.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    prompt = f'Question: {question}? Answer:'

    # ``read_video`` takes a path, so the upload is written to disk first. The
    # file is closed (and therefore flushed) before it is read back, and the
    # temporary directory removes it again even if inference raises.
    with tempfile.TemporaryDirectory() as workdir:
        video_path = f"{workdir}/upload.mp4"
        with open(video_path, "wb") as fh:
            # getvalue() returns the full payload regardless of the stream
            # position left behind by an earlier read.
            fh.write(uploaded_file.getvalue())

        video = read_video(video_path, transform=None, frames_num=4)
        # Explicit length check: the frame container's truthiness is not
        # known to be well defined.
        if len(video) == 0:
            return None

        # Tile the sampled frames into a 2x2 grid and turn it into a one-item
        # batch while the file still exists.
        grid = image_grid(video, 2, 2)
        image = preprocess(grid).unsqueeze(0).to(device)

    with torch.no_grad():
        prefix = clip_model.encode_image(image).to(device, dtype=torch.float32)

    ans = get_ans(model, tokenizer, prefix, prefix_length, prompt)
    return ans['answer'].split(' A: ')[0]


def main():
    """Render the page: upload a video, ask a question, show the model's answer."""
    model, clip_model, preprocess, tokenizer = load_model()
    # Must match the prefix length the caption model is built with in load_model.
    prefix_length = 50

    # Title and sidebar emoji restored from mojibake (mis-decoded UTF-8).
    st.title("🦾 Video Analysis for Education")
    st.header("")

    with st.sidebar.expander("ℹ️ - About application", expanded=True):
        st.write(_ABOUT_TEXT)

    # NOTE(review): the leading glyphs of the uploader, question and answer
    # labels below are mis-decoded emoji whose originals cannot be recovered
    # from the surviving characters -- restore them from the upstream file.
    uploaded_file = st.file_uploader("๐ Upload video: ", ['.mp4'])
    if uploaded_file is not None:
        st.write('success')
    else:
        st.write('no')

    st.write("---")

    question_col, send_col = st.columns([4, 1])
    question = question_col.text_input(
        label="โ Enter question prompt: ",
        placeholder="",
    )
    send_clicked = send_col.button("Send", use_container_width=True)

    if not send_clicked:
        return

    # Without this guard, pressing "Send" before uploading dereferences None.
    if uploaded_file is None:
        st.warning("Please upload a video before sending a question.")
        return

    answer = _answer_question(
        uploaded_file, question, model, clip_model, preprocess, tokenizer, prefix_length
    )
    if answer is None:
        st.warning("Could not read any frames from the uploaded video.")
        return

    st.text_input('โ Answer to the question', answer, disabled=False)


if __name__ == '__main__':
    main()