File size: 5,347 Bytes
fd19cdd
6255a7a
fd19cdd
 
36d8cb0
 
6255a7a
 
 
fd19cdd
ab2cf62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e4d4b77
ab2cf62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6d662e5
ab2cf62
 
 
 
 
 
 
 
 
 
 
 
 
 
53617f0
ab2cf62
 
8603deb
ab2cf62
53617f0
ab2cf62
 
 
 
 
 
 
 
 
 
 
 
 
53617f0
ab2cf62
 
e4d4b77
 
 
 
5db4df9
 
 
 
 
 
 
 
 
ab2cf62
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import streamlit as st
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
from PIL import Image
import torch
import cv2
import tempfile
from langchain import LLMChain, PromptTemplate
from langchain_community.llms import Ollama
from langchain_core.output_parsers import StrOutputParser

# Step 1: Load the model
def load_model():
    """Load the Qwen2-VL processor and model, reusing them across reruns.

    Streamlit re-executes the whole script on every widget interaction, so
    the expensive ``from_pretrained`` calls are wrapped in
    ``st.cache_resource``: the first call loads the weights and later calls
    get the same objects back.

    Returns:
        A ``(processor, model, device)`` tuple, where ``device`` is CUDA when
        ``torch.cuda.is_available()`` is true and CPU otherwise, and the model
        has already been moved to that device.
    """
    model_id = "Qwen/Qwen2-VL-2B-Instruct"

    # The decorator is applied here rather than at module level so that it
    # is only evaluated when the function is called; the cache entry is
    # shared between calls because the wrapped function is the same each time.
    @st.cache_resource(show_spinner=False)
    def _load_cached():
        processor = AutoProcessor.from_pretrained(model_id)
        model = Qwen2VLForConditionalGeneration.from_pretrained(model_id)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model.to(device)
        return processor, model, device

    st.write("Loading the model...")
    processor, model, device = _load_cached()
    st.write("Model loaded successfully!")
    return processor, model, device

# Step 2: Upload image or video
def upload_media():
    """Show the file-upload widget and return its current value.

    The widget accepts several files at once and is restricted to the image
    and video extensions listed below.
    """
    accepted_extensions = ["jpg", "jpeg", "png", "mp4", "avi", "mov"]
    selection = st.file_uploader(
        "Choose images or videos...",
        type=accepted_extensions,
        accept_multiple_files=True,
    )
    return selection

# Step 3: Enter your question
def get_user_question():
    """Show the question text box and return whatever the user has typed."""
    label = "Ask a question about the images or videos:"
    typed_text = st.text_input(label)
    return typed_text

# Process image
def process_image(uploaded_file):
    """Open an uploaded image, shrink it, display it and return it.

    Args:
        uploaded_file: A file-like object accepted by ``Image.open``.

    Returns:
        The 256x256 RGB ``PIL.Image``, or ``None`` when the file cannot be
        decoded (an error is shown in the UI in that case, mirroring how
        ``process_video`` reports an unreadable video).
    """
    try:
        # ``convert`` forces the pixel data to be decoded now, so truncated
        # or corrupt files fail here rather than later in ``resize``. It
        # also normalises palette/alpha/greyscale inputs to the same RGB
        # layout that ``process_video`` produces.
        image = Image.open(uploaded_file).convert("RGB")
    except OSError:
        # PIL's "cannot identify image file" error is an OSError subclass.
        st.error("Failed to read the image file.")
        return None

    image = image.resize((256, 256))  # Reduce size to save memory
    st.image(image, caption='Uploaded Image.', use_column_width=True)
    return image

# Process video
def process_video(uploaded_file):
    """Extract the first frame of an uploaded video as a small RGB image.

    The upload is spooled to a temporary file because ``cv2.VideoCapture``
    is given a filesystem path. The temporary file is closed before OpenCV
    opens it (so all bytes are on disk) and removed afterwards.

    Args:
        uploaded_file: A file-like object with a ``read()`` method; its
            ``name`` attribute, when present, supplies the file extension
            for the temporary copy.

    Returns:
        The first frame as a 256x256 RGB ``PIL.Image``, or ``None`` when no
        frame could be read (an error is shown in the UI in that case).
    """
    import contextlib
    import os

    # Keep the original extension on the temporary copy.
    suffix = os.path.splitext(getattr(uploaded_file, "name", "") or "")[1]

    # Leaving the ``with`` block flushes and closes the file; delete=False
    # keeps it on disk so it can be reopened by name below.
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tfile:
        tfile.write(uploaded_file.read())
        temp_path = tfile.name

    try:
        cap = cv2.VideoCapture(temp_path)
        try:
            ret, frame = cap.read()
        finally:
            cap.release()
    finally:
        # Best-effort cleanup: a failure to delete must not mask the result.
        with contextlib.suppress(OSError):
            os.unlink(temp_path)

    if not ret:
        st.error("Failed to read the video file.")
        return None

    # OpenCV returns BGR; convert to RGB before building the PIL image.
    image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    image = image.resize((256, 256))  # Reduce size to save memory
    st.image(image, caption='First Frame of Uploaded Video.', use_column_width=True)
    return image

# Generate description
def generate_description(processor, model, device, image, user_question):
    """Ask the vision-language model a question about one image.

    Builds a single-turn chat containing the image and the question, runs
    generation with up to 512 new tokens, strips the prompt tokens from the
    output and returns the decoded answer text.

    Args:
        processor: Object providing ``apply_chat_template``, ``__call__``
            and ``batch_decode``.
        model: Object providing ``generate``.
        device: Device the processed inputs are moved to before generation.
        image: The image passed to the processor.
        user_question: The question text placed in the chat message.

    Returns:
        The first (and only) decoded answer string.
    """
    image_part = {"type": "image", "image": image}
    question_part = {"type": "text", "text": user_question}
    conversation = [{"role": "user", "content": [image_part, question_part]}]

    prompt = processor.apply_chat_template(
        conversation, tokenize=False, add_generation_prompt=True
    )
    batch = processor(
        text=[prompt], images=[image], padding=True, return_tensors="pt"
    )
    batch = batch.to(device)

    full_sequences = model.generate(**batch, max_new_tokens=512)

    # generate() returns prompt + completion; keep only the completion part.
    completions = []
    for prompt_ids, sequence in zip(batch.input_ids, full_sequences):
        completions.append(sequence[len(prompt_ids):])

    decoded = processor.batch_decode(
        completions,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )
    return decoded[0]

# Generate story
def generate_story(descriptions):
    """Turn a list of description strings into a short story.

    The descriptions are joined with single spaces, substituted into a fixed
    prompt and sent through an ``LLMChain`` backed by the Ollama
    ``llama3.1`` model.

    Args:
        descriptions: Iterable of strings to base the story on.

    Returns:
        The result of running the chain on the joined descriptions.
    """
    joined_descriptions = " ".join(descriptions)

    story_prompt = PromptTemplate(
        input_variables=["descriptions"],
        template="Based on the following descriptions, create a short story:\n\n{descriptions}\n\nStory:",
    )
    story_chain = LLMChain(
        llm=Ollama(model="llama3.1"),
        prompt=story_prompt,
        output_parser=StrOutputParser(),
    )
    return story_chain.run({"descriptions": joined_descriptions})

# Main function to control the flow
def main():
    """Drive the Streamlit page: upload, question, descriptions, story.

    Each stage is only shown once the previous one has produced a value.
    Descriptions are kept in ``st.session_state["all_output_texts"]`` so the
    story button can still use them on the rerun triggered by clicking it.
    """
    st.title("Media Story Generator")

    # Step 1: Load the model
    processor, model, device = load_model()

    # Step 2: Upload image or video
    uploaded_files = upload_media()
    if not uploaded_files:
        return

    # Step 3: Enter your question
    user_question = get_user_question()
    if not user_question:
        return

    # Step 4: Generate description
    st.write("Step 4: Generate description")
    if st.button("Generate Descriptions", key="generate_descriptions"):
        collected = []

        for position, uploaded_file in enumerate(uploaded_files, start=1):
            # The MIME type's major part ("image" / "video") picks the loader.
            media_kind = uploaded_file.type.split('/')[0]

            if media_kind == 'image':
                frame = process_image(uploaded_file)
            elif media_kind == 'video':
                frame = process_video(uploaded_file)
            else:
                st.error("Unsupported file type.")
                continue

            if not frame:
                continue

            answer = generate_description(processor, model, device, frame, user_question)
            st.write(f"Description for file {position}:")
            st.write(answer)
            collected.append(answer)

        # Store descriptions in session state
        st.session_state["all_output_texts"] = collected

    # Only offer the story step once descriptions exist in session state
    if "all_output_texts" not in st.session_state:
        return
    saved_descriptions = st.session_state["all_output_texts"]
    if not saved_descriptions:
        return

    st.write("Generate story")
    if st.button("Generate Story", key="generate_story"):
        story = generate_story(saved_descriptions)
        st.write("Generated Story:")
        st.write(story)

# Run the app only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()