File size: 3,380 Bytes
7ae0a5f 75c398d 7ae0a5f c263cca 75c398d 9e85783 7ae0a5f fd14c39 8f43f6f 75c398d 8f43f6f 9e85783 c263cca 8f43f6f c263cca 8f43f6f 9a07bf0 8f43f6f a263f51 c263cca a263f51 75c398d 7ae0a5f a263f51 8f43f6f 75c398d 9a07bf0 8f43f6f 7ae0a5f a263f51 7ae0a5f c263cca e16425f 7ae0a5f 9a07bf0 7ae0a5f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 |
import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from PIL import Image
import io
import importlib
def check_transformers_version():
import transformers
return transformers.__version__
@st.cache_resource
def load_model():
model_name = "Qwen/Qwen2-VL-7B-Instruct"
try:
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", trust_remote_code=True)
return tokenizer, model
except Exception as e:
st.error(f"Ошибка при загрузке модели: {str(e)}")
return None, None
def generate_response(prompt, image, tokenizer, model):
if tokenizer is None or model is None:
return "Модель не загружена. Пожалуйста, проверьте ошибки выше."
try:
if image:
image = Image.open(image).convert('RGB')
inputs = tokenizer(prompt, images=[image], return_tensors='pt').to(model.device)
else:
inputs = tokenizer(prompt, return_tensors='pt').to(model.device)
with torch.no_grad():
outputs = model.generate(**inputs, max_new_tokens=100)
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
return response
except Exception as e:
return f"Ошибка при генерации ответа: {str(e)}"
st.title("Чат с Qwen VL-7B-Instruct")
transformers_version = check_transformers_version()
st.info(f"Версия transformers: {transformers_version}")
tokenizer, model = load_model()
if tokenizer is None or model is None:
st.warning("Модель не загружена. Приложение может работать некорректно.")
st.info("Попробуйте установить последнюю версию transformers: pip install transformers --upgrade")
else:
st.success("Модель успешно загружена!")
if "messages" not in st.session_state:
st.session_state.messages = []
for message in st.session_state.messages:
with st.chat_message(message["role"]):
st.markdown(message["content"])
if "image" in message:
st.image(message["image"])
prompt = st.chat_input("Введите ваше сообщение")
uploaded_file = st.file_uploader("Загрузите изображение (необязательно)", type=["png", "jpg", "jpeg"])
if prompt or uploaded_file:
if uploaded_file:
image = Image.open(uploaded_file)
st.session_state.messages.append({"role": "user", "content": prompt or "Опишите это изображение", "image": uploaded_file})
with st.chat_message("user"):
if prompt:
st.markdown(prompt)
st.image(image)
else:
st.session_state.messages.append({"role": "user", "content": prompt})
with st.chat_message("user"):
st.markdown(prompt)
with st.chat_message("assistant"):
with st.spinner("Генерация ответа..."):
response = generate_response(prompt, uploaded_file, tokenizer, model)
st.markdown(response)
st.session_state.messages.append({"role": "assistant", "content": response}) |