File size: 3,380 Bytes
7ae0a5f
75c398d
7ae0a5f
c263cca
 
75c398d
 
 
 
 
9e85783
7ae0a5f
 
fd14c39
8f43f6f
75c398d
8f43f6f
 
 
 
 
9e85783
c263cca
8f43f6f
 
c263cca
8f43f6f
 
 
9a07bf0
8f43f6f
 
 
 
 
 
 
 
 
 
a263f51
c263cca
a263f51
75c398d
 
 
7ae0a5f
a263f51
8f43f6f
 
75c398d
9a07bf0
 
8f43f6f
7ae0a5f
 
a263f51
7ae0a5f
 
 
c263cca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e16425f
7ae0a5f
9a07bf0
 
7ae0a5f
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from PIL import Image
import io
import importlib

def check_transformers_version():
    import transformers
    return transformers.__version__

@st.cache_resource
def load_model():
    model_name = "Qwen/Qwen2-VL-7B-Instruct"
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", trust_remote_code=True)
        return tokenizer, model
    except Exception as e:
        st.error(f"Ошибка при загрузке модели: {str(e)}")
        return None, None

def generate_response(prompt, image, tokenizer, model):
    if tokenizer is None or model is None:
        return "Модель не загружена. Пожалуйста, проверьте ошибки выше."
    
    try:
        if image:
            image = Image.open(image).convert('RGB')
            inputs = tokenizer(prompt, images=[image], return_tensors='pt').to(model.device)
        else:
            inputs = tokenizer(prompt, return_tensors='pt').to(model.device)
        
        with torch.no_grad():
            outputs = model.generate(**inputs, max_new_tokens=100)
        
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        return response
    except Exception as e:
        return f"Ошибка при генерации ответа: {str(e)}"

st.title("Чат с Qwen VL-7B-Instruct")

transformers_version = check_transformers_version()
st.info(f"Версия transformers: {transformers_version}")

tokenizer, model = load_model()

if tokenizer is None or model is None:
    st.warning("Модель не загружена. Приложение может работать некорректно.")
    st.info("Попробуйте установить последнюю версию transformers: pip install transformers --upgrade")
else:
    st.success("Модель успешно загружена!")

if "messages" not in st.session_state:
    st.session_state.messages = []

for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])
        if "image" in message:
            st.image(message["image"])

prompt = st.chat_input("Введите ваше сообщение")
uploaded_file = st.file_uploader("Загрузите изображение (необязательно)", type=["png", "jpg", "jpeg"])

if prompt or uploaded_file:
    if uploaded_file:
        image = Image.open(uploaded_file)
        st.session_state.messages.append({"role": "user", "content": prompt or "Опишите это изображение", "image": uploaded_file})
        with st.chat_message("user"):
            if prompt:
                st.markdown(prompt)
            st.image(image)
    else:
        st.session_state.messages.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.markdown(prompt)
    
    with st.chat_message("assistant"):
        with st.spinner("Генерация ответа..."):
            response = generate_response(prompt, uploaded_file, tokenizer, model)
        st.markdown(response)
    
    st.session_state.messages.append({"role": "assistant", "content": response})