ffgtv3 committed (verified)
Commit 8f43f6f · Parent: d126618

Update app.py

Files changed (1):
  1. app.py +26 -13
app.py CHANGED
@@ -7,27 +7,40 @@ import io
 @st.cache_resource
 def load_model():
     model_name = "Qwen/Qwen2-VL-7B-Instruct"
-    tokenizer = AutoTokenizer.from_pretrained(model_name)
-    model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
-    return tokenizer, model
+    try:
+        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+        model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", trust_remote_code=True)
+        return tokenizer, model
+    except Exception as e:
+        st.error(f"Error while loading the model: {str(e)}")
+        return None, None
 
 def generate_response(prompt, image, tokenizer, model):
-    if image:
-        image = Image.open(image).convert('RGB')
-        inputs = tokenizer.from_pretrained(prompt, images=[image], return_tensors='pt').to(model.device)
-    else:
-        inputs = tokenizer(prompt, return_tensors='pt').to(model.device)
-
-    with torch.no_grad():
-        outputs = model.generate(**inputs, max_new_tokens=100)
-
-    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
-    return response
+    if tokenizer is None or model is None:
+        return "The model is not loaded. Please check the errors above."
+
+    try:
+        if image:
+            image = Image.open(image).convert('RGB')
+            inputs = tokenizer.from_pretrained(prompt, images=[image], return_tensors='pt').to(model.device)
+        else:
+            inputs = tokenizer(prompt, return_tensors='pt').to(model.device)
+
+        with torch.no_grad():
+            outputs = model.generate(**inputs, max_new_tokens=100)
+
+        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+        return response
+    except Exception as e:
+        return f"Error while generating a response: {str(e)}"
 
 st.title("Chat with Qwen VL-7B-Instruct")
 
 tokenizer, model = load_model()
 
+if tokenizer is None or model is None:
+    st.warning("The model is not loaded. The app may not work correctly.")
+
 if "messages" not in st.session_state:
     st.session_state.messages = []
 
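Note: the call inputs = tokenizer.from_pretrained(prompt, images=[image], ...) survives this commit unchanged. from_pretrained is a class-method loader that expects a model name or path, and AutoTokenizer does not accept images, so the image branch will still fail at runtime; the new try/except reports that error rather than fixing it. Qwen2-VL is also a vision-language model, so AutoModelForCausalLM is not the matching model class. Below is a minimal sketch of how this pair of functions is usually written for Qwen2-VL, using AutoProcessor and Qwen2VLForConditionalGeneration; it assumes transformers >= 4.45, follows the message format from the model card, and is not part of this commit.

import torch
import streamlit as st
from PIL import Image
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration

@st.cache_resource
def load_model():
    model_name = "Qwen/Qwen2-VL-7B-Instruct"
    # AutoProcessor bundles the tokenizer and the image preprocessor,
    # so one object handles both text-only and text+image prompts.
    processor = AutoProcessor.from_pretrained(model_name)
    model = Qwen2VLForConditionalGeneration.from_pretrained(
        model_name, torch_dtype="auto", device_map="auto"
    )
    return processor, model

def generate_response(prompt, image_file, processor, model):
    # Build a chat-formatted user turn; an image placeholder is added
    # only when a file was actually uploaded.
    content = [{"type": "text", "text": prompt}]
    images = None
    if image_file:
        images = [Image.open(image_file).convert("RGB")]
        content.insert(0, {"type": "image"})
    text = processor.apply_chat_template(
        [{"role": "user", "content": content}],
        tokenize=False,
        add_generation_prompt=True,
    )
    inputs = processor(text=[text], images=images, return_tensors="pt").to(model.device)
    with torch.no_grad():
        output_ids = model.generate(**inputs, max_new_tokens=100)
    # Strip the prompt tokens so only the newly generated answer is decoded.
    new_ids = output_ids[:, inputs["input_ids"].shape[1]:]
    return processor.batch_decode(new_ids, skip_special_tokens=True)[0]

Trimming the prompt tokens before decoding keeps the chat window from echoing the user's own message back, which the original tokenizer.decode(outputs[0], ...) would do.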