# VinitT's picture
# Update app.py
# 53617f0 verified
# raw
# history blame
# 5.32 kB
import streamlit as st
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
from PIL import Image
import torch
import cv2
import tempfile
from langchain import LLMChain, PromptTemplate
from langchain_community.llms import Ollama
from langchain_core.output_parsers import StrOutputParser
# Step 1: Load the model
@st.cache_resource
def load_model():
    """Load the Qwen2-VL processor and model, caching them across reruns.

    Streamlit re-executes the script on every widget interaction, so an
    uncached loader would rebuild the model each time the user types or
    clicks. ``st.cache_resource`` keeps a single shared instance instead.

    Returns:
        tuple: ``(processor, model, device)`` where ``device`` is CUDA when
        ``torch.cuda.is_available()`` and CPU otherwise; the model has
        already been moved to that device.
    """
    model_id = "Qwen/Qwen2-VL-2B-Instruct"

    st.write("Loading the model...")
    processor = AutoProcessor.from_pretrained(model_id)
    model = Qwen2VLForConditionalGeneration.from_pretrained(model_id)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    st.write("Model loaded successfully!")
    return processor, model, device
# Step 2: Upload image or video
def upload_media():
    """Render the multi-file uploader and return what Streamlit gives back.

    Only the listed image and video extensions are accepted by the widget.
    """
    accepted_extensions = ["jpg", "jpeg", "png", "mp4", "avi", "mov"]
    uploads = st.file_uploader(
        "Choose images or videos...",
        type=accepted_extensions,
        accept_multiple_files=True,
    )
    return uploads
# Step 3: Enter your question
def get_user_question():
    """Render the question text box and return its current contents."""
    question = st.text_input("Ask a question about the images or videos:")
    return question
# Process image
def process_image(uploaded_file):
    """Open an uploaded image, shrink it, preview it, and return it.

    Args:
        uploaded_file: File-like object that ``PIL.Image.open`` can read.

    Returns:
        The image resized to 512x512 (aspect ratio is not preserved).
    """
    # A fixed, small size keeps memory use down for the model.
    target_size = (512, 512)
    resized = Image.open(uploaded_file).resize(target_size)
    st.image(resized, caption='Uploaded Image.', use_column_width=True)
    return resized
# Process video
def process_video(uploaded_file):
    """Extract, preview, and return the first frame of an uploaded video.

    The upload is spooled to a temporary file because ``cv2.VideoCapture``
    is given a filesystem path. The temporary file is closed before OpenCV
    opens it (so all bytes are on disk) and removed afterwards.

    Args:
        uploaded_file: File-like object whose ``read()`` returns the video
            bytes.

    Returns:
        The first frame as an RGB ``PIL.Image`` resized to 256x256, or
        ``None`` (after showing ``st.error``) when no frame could be read.
    """
    import os  # local import: only needed for temp-file cleanup

    tfile = tempfile.NamedTemporaryFile(delete=False)
    try:
        # Leaving the ``with`` block flushes and closes the file, so the
        # reader below sees the complete data.
        with tfile:
            tfile.write(uploaded_file.read())

        cap = cv2.VideoCapture(tfile.name)
        try:
            ret, frame = cap.read()
        finally:
            cap.release()
    finally:
        # delete=False means nobody else will remove the file for us.
        try:
            os.unlink(tfile.name)
        except OSError:
            # Best-effort cleanup; a leftover temp file must not break the app.
            pass

    if not ret:
        st.error("Failed to read the video file.")
        return None

    # OpenCV yields BGR; PIL expects RGB.
    image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    image = image.resize((256, 256))  # Reduce size to save memory
    st.image(image, caption='First Frame of Uploaded Video.', use_column_width=True)
    return image
# Generate description
def generate_description(processor, model, device, image, user_question):
    """Ask the vision-language model a question about a single image.

    Args:
        processor: Object providing ``apply_chat_template``, a callable
            tokenization interface, and ``batch_decode``.
        model: Object providing ``generate``.
        device: Device the processed inputs are moved to before generation.
        image: The image passed both in the chat message and to the processor.
        user_question: Text placed alongside the image in the user turn.

    Returns:
        The first decoded string produced for the single-item batch.
    """
    user_turn = {
        "role": "user",
        "content": [
            {"type": "image", "image": image},
            {"type": "text", "text": user_question},
        ],
    }
    prompt = processor.apply_chat_template(
        [user_turn], tokenize=False, add_generation_prompt=True
    )

    batch = processor(text=[prompt], images=[image], padding=True, return_tensors="pt")
    batch = batch.to(device)

    sequences = model.generate(**batch, max_new_tokens=512)

    # Drop the leading len(input_ids) tokens of every generated sequence
    # (presumably the echoed prompt) so only the newly produced part is decoded.
    completions = []
    for prompt_ids, sequence in zip(batch.input_ids, sequences):
        completions.append(sequence[len(prompt_ids):])

    decoded = processor.batch_decode(
        completions,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )
    return decoded[0]
# Generate story
def generate_story(descriptions):
    """Turn a list of description strings into a short story via Ollama.

    Args:
        descriptions: Iterable of strings; they are joined with single spaces
            and substituted into the story prompt.

    Returns:
        Whatever ``LLMChain.run`` returns for the filled-in prompt.
    """
    joined = " ".join(descriptions)

    story_prompt = PromptTemplate(
        template="Based on the following descriptions, create a short story:\n\n{descriptions}\n\nStory:",
        input_variables=["descriptions"],
    )
    story_chain = LLMChain(
        llm=Ollama(model="llama3.1"),
        prompt=story_prompt,
        output_parser=StrOutputParser(),
    )
    return story_chain.run({"descriptions": joined})
# Main function to control the flow
def _describe_files(processor, model, device, uploaded_files, user_question):
    """Describe every supported upload and return ``(file_number, text)`` pairs.

    Files whose MIME top-level type is neither ``image`` nor ``video`` are
    reported with ``st.error`` and skipped, as are media that could not be
    turned into an image.
    """
    results = []
    for idx, uploaded_file in enumerate(uploaded_files):
        file_type = uploaded_file.type.split('/')[0]
        if file_type == 'image':
            image = process_image(uploaded_file)
        elif file_type == 'video':
            image = process_video(uploaded_file)
        else:
            st.error("Unsupported file type.")
            continue

        if image is None:
            continue

        # Seed before every generation so each file, including the first,
        # starts from the same RNG state.
        torch.manual_seed(0)
        description = generate_description(processor, model, device, image, user_question)
        st.write(f"Description for file {idx + 1}:")
        st.write(description)
        results.append((idx + 1, description))

        # Clear memory after processing each file
        del image
        torch.cuda.empty_cache()
    return results


def main():
    """Drive the app: load model, collect media and question, describe, narrate.

    Descriptions are kept in ``st.session_state`` rather than a local
    variable. Clicking any button makes Streamlit rerun the script, and on
    that rerun the *other* button reads as not pressed; a "Generate Story"
    button nested under the "Generate Descriptions" branch could therefore
    never fire. Persisting the descriptions lets the two steps live side by
    side.
    """
    state_key = "descriptions"

    st.title("Media Story Generator")

    # Step 1: Load the model
    processor, model, device = load_model()

    # Step 2: Upload image or video
    uploaded_files = upload_media()
    if not uploaded_files:
        return

    # Step 3: Enter your question
    user_question = get_user_question()
    if not user_question:
        return

    # Step 4: Generate description
    st.write("Step 4: Generate description")
    if st.button("Generate Descriptions", key="generate_descriptions"):
        st.session_state[state_key] = _describe_files(
            processor, model, device, uploaded_files, user_question
        )
    else:
        # Re-show earlier results so they survive reruns triggered by
        # other widgets (e.g. the story button below).
        for file_number, description in st.session_state.get(state_key, []):
            st.write(f"Description for file {file_number}:")
            st.write(description)

    described = st.session_state.get(state_key, [])
    if not described:
        return

    # Step 5: Generate story
    st.write("Step 5: Generate story")
    if st.button("Generate Story", key="generate_story"):
        story = generate_story([description for _, description in described])
        st.write("Generated Story:")
        st.write(story)
# Start the app only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()