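"""Streamlit demo: describe uploaded images and videos with Qwen2-VL, then weave the
per-file descriptions into a short story with a locally served Ollama model via LangChain.

A minimal sketch of how to run it (assuming this file is saved as app.py and an Ollama
server is running locally with the `llama3.1` model pulled):

    streamlit run app.py
"""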
import streamlit as st
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
from PIL import Image
import torch
import cv2
import tempfile
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain_community.llms import Ollama
from langchain_core.output_parsers import StrOutputParser

# Load the processor and model directly
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
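# Note: the 2B-Instruct weights are downloaded from the Hugging Face Hub on the first run.
# Passing torch_dtype=torch.float16 to from_pretrained is an optional way to reduce GPU
# memory; the default full precision is kept here for simplicity.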

# Check if CUDA is available and set the device accordingly
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Streamlit app
st.title("Media Description Generator")

uploaded_files = st.file_uploader("Choose images or videos...", type=["jpg", "jpeg", "png", "mp4", "avi", "mov"], accept_multiple_files=True)

if uploaded_files:
    user_question = st.text_input("Ask a question about the images or videos:")

    if user_question:
        all_output_texts = []  # Initialize an empty list to store all output texts

        for uploaded_file in uploaded_files:
            file_type = uploaded_file.type.split('/')[0]

            if file_type == 'image':
                # Open the image
                image = Image.open(uploaded_file)
                # Resize image to reduce memory usage
                image = image.resize((512, 512))
                st.image(image, caption='Uploaded Image.', use_column_width=True)
                st.write("Generating description...")

            elif file_type == 'video':
                # Save the uploaded video to a temporary file so OpenCV can open it by path
                tfile = tempfile.NamedTemporaryFile(delete=False)
                tfile.write(uploaded_file.read())
                tfile.flush()  # make sure the bytes are on disk before cv2 reads the file

                # Open the video file
                cap = cv2.VideoCapture(tfile.name)

                # Extract the first frame
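                # (only this single frame is sent to the model; the rest of the video is ignored)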
                ret, frame = cap.read()
                if not ret:
                    st.error("Failed to read the video file.")
                    continue
                else:
                    # Convert the frame to an image
                    image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                    # Resize image to reduce memory usage
                    image = image.resize((512, 512))
                    st.image(image, caption='First Frame of Uploaded Video.', use_column_width=True)
                    st.write("Generating description...")

                # Release the video capture object
                cap.release()

            else:
                st.error("Unsupported file type.")
                continue

            # Ensure the image is loaded correctly
            if image is None:
                st.error("Failed to load the image.")
                continue

            messages = [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "image",
                            "image": image,
                        },
                        {"type": "text", "text": user_question},
                    ],
                }
            ]

            # Preparation for inference
            text = processor.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )

            # Pass the image to the processor
            inputs = processor(
                text=[text],
                images=[image],
                padding=True,
                return_tensors="pt",
            )
            inputs = inputs.to(device)  # Ensure inputs are on the same device as the model

            # Inference: Generation of the output
            try:
                generated_ids = model.generate(**inputs, max_new_tokens=512)
                generated_ids_trimmed = [
                    out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
                ]
                output_text = processor.batch_decode(
                    generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
                )

                st.write("Description:")
                st.write(output_text[0])

                # Append the output text to the list
                all_output_texts.append(output_text[0])

            except Exception as e:
                st.error(f"Error during generation: {e}")
                continue

            # Clear memory after processing each file
            del image, inputs, generated_ids, generated_ids_trimmed, output_text
            torch.cuda.empty_cache()
            torch.manual_seed(0)  # Re-seed so any sampling during generation is reproducible across files

        # Combine all descriptions into a single text
        combined_text = " ".join(all_output_texts)

        # Define the prompt template for LangChain
        prompt_template = PromptTemplate(
            input_variables=["descriptions"],
            template="Based on the following descriptions, create a short story:\n\n{descriptions}\n\nStory:"
        )

        # Create the LLMChain with the Ollama model
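        # Assumes an Ollama server is running locally (default port 11434) with the
        # `llama3.1` model already pulled, e.g. via `ollama pull llama3.1`.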
        ollama_llm = Ollama(model="llama3.1")
        output_parser = StrOutputParser()
        chain = LLMChain(
            llm=ollama_llm,
            prompt=prompt_template,
            output_parser=output_parser
        )

        # Generate the story using LangChain
        story = chain.run({"descriptions": combined_text})

        # Display the generated story
        st.write("Generated Story:")
        st.write(story)