File size: 3,073 Bytes
7ae0a5f
9a07bf0
 
7ae0a5f
c263cca
 
9e85783
7ae0a5f
 
fd14c39
8f43f6f
9a07bf0
8f43f6f
 
 
 
 
9e85783
c263cca
8f43f6f
 
c263cca
8f43f6f
 
 
9a07bf0
8f43f6f
 
 
 
 
 
 
 
 
 
a263f51
c263cca
a263f51
7ae0a5f
a263f51
8f43f6f
 
9a07bf0
 
8f43f6f
7ae0a5f
 
a263f51
7ae0a5f
 
 
c263cca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e16425f
7ae0a5f
9a07bf0
 
7ae0a5f
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import streamlit as st
from transformers import AutoModelForCausalLM
from transformers.models.qwen2.tokenization_qwen2 import Qwen2Tokenizer
import torch
from PIL import Image
import io

@st.cache_resource
def load_model():
    model_name = "Qwen/Qwen2-VL-7B-Instruct"
    try:
        tokenizer = Qwen2Tokenizer.from_pretrained(model_name, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", trust_remote_code=True)
        return tokenizer, model
    except Exception as e:
        st.error(f"Ошибка при загрузке модели: {str(e)}")
        return None, None

def generate_response(prompt, image, tokenizer, model):
    if tokenizer is None or model is None:
        return "Модель не загружена. Пожалуйста, проверьте ошибки выше."
    
    try:
        if image:
            image = Image.open(image).convert('RGB')
            inputs = tokenizer(prompt, images=[image], return_tensors='pt').to(model.device)
        else:
            inputs = tokenizer(prompt, return_tensors='pt').to(model.device)
        
        with torch.no_grad():
            outputs = model.generate(**inputs, max_new_tokens=100)
        
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        return response
    except Exception as e:
        return f"Ошибка при генерации ответа: {str(e)}"

st.title("Чат с Qwen VL-7B-Instruct")

tokenizer, model = load_model()

if tokenizer is None or model is None:
    st.warning("Модель не загружена. Приложение может работать некорректно.")
else:
    st.success("Модель успешно загружена!")

if "messages" not in st.session_state:
    st.session_state.messages = []

for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])
        if "image" in message:
            st.image(message["image"])

prompt = st.chat_input("Введите ваше сообщение")
uploaded_file = st.file_uploader("Загрузите изображение (необязательно)", type=["png", "jpg", "jpeg"])

if prompt or uploaded_file:
    if uploaded_file:
        image = Image.open(uploaded_file)
        st.session_state.messages.append({"role": "user", "content": prompt or "Опишите это изображение", "image": uploaded_file})
        with st.chat_message("user"):
            if prompt:
                st.markdown(prompt)
            st.image(image)
    else:
        st.session_state.messages.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.markdown(prompt)
    
    with st.chat_message("assistant"):
        with st.spinner("Генерация ответа..."):
            response = generate_response(prompt, uploaded_file, tokenizer, model)
        st.markdown(response)
    
    st.session_state.messages.append({"role": "assistant", "content": response})