yejunliang23 committed (unverified)
Commit eb04357 · 1 Parent(s): 3c5e0ed

Update app.py

Files changed (1)
  1. app.py +52 -8
app.py CHANGED
@@ -2,7 +2,8 @@ import os
 import torch
 from threading import Thread
 import gradio as gr
-from transformers import AutoProcessor, AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
+from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor, TextIteratorStreamer
+from qwen_vl_utils import process_vision_info
 
 # 3D mesh dependencies
 import trimesh
@@ -13,20 +14,63 @@ import tempfile
 
 # --------- Configuration & Model Loading ---------
 MODEL_DIR = "Qwen/Qwen2-VL-7B-Instruct"
 # Load processor, tokenizer, model for Qwen2.5-VL
-processor = AutoProcessor.from_pretrained(MODEL_DIR)
-tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, trust_remote_code=True)
-model = AutoModelForCausalLM.from_pretrained(
+model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
     MODEL_DIR,
     torch_dtype=torch.float16,
     device_map="auto",
     trust_remote_code=True
 )
-
-# Terminator tokens
-terminators = [tokenizer.eos_token_id]
+processor = AutoProcessor.from_pretrained(MODEL_DIR)
+# Terminator tokens (kept: still referenced by eos_token_id below)
+terminators = [processor.tokenizer.eos_token_id]
 
 # --------- Chat Inference Function ---------
-def chat_qwen_vl(message: str, history: list, temperature: float = 0.7, max_new_tokens: int = 1024):
+def chat_qwen_vl(messages):
+    # -- Existing multimodal input construction -- #
+    text = processor.apply_chat_template(
+        messages, tokenize=False, add_generation_prompt=True
+    )
+    image_inputs, video_inputs = process_vision_info(messages)
+    inputs = processor(
+        text=[text],
+        images=image_inputs,
+        videos=video_inputs,
+        padding=True,
+        return_tensors="pt"
+    ).to(model.device)
+
+    # -- Streaming generation -- #
+    # 1. Build the streamer from processor.tokenizer (AutoProcessor bundles its own tokenizer)
+    streamer = TextIteratorStreamer(
+        processor.tokenizer,
+        timeout=10.0,
+        skip_prompt=True,
+        skip_special_tokens=True
+    )
+
+    # 2. Pass the streamer together with the generation parameters to model.generate
+    gen_kwargs = dict(
+        **inputs,            # input_ids, pixel_values, attention_mask, etc.
+        streamer=streamer,   # key step: attach the streamer
+        top_k=1024,
+        max_new_tokens=1280,
+        temperature=0.1,
+        top_p=0.1,
+        eos_token_id=terminators,  # list of end-of-sequence token IDs
+    )
+    # For zero-temperature greedy decoding, disable sampling
+    if gen_kwargs["temperature"] == 0:
+        gen_kwargs["do_sample"] = False
+
+    # 3. Launch generation in a background thread
+    Thread(target=model.generate, kwargs=gen_kwargs).start()
+
+    # 4. Read from the streamer in the main thread and yield incrementally
+    buffer = []
+    for chunk in streamer:
+        buffer.append(chunk)
+        # Append each new fragment and emit the accumulated text
+        yield "".join(buffer)
+
+def chat_qwen_vl_(message: str, history: list, temperature: float = 0.7, max_new_tokens: int = 1024):
     """
     Stream chat response from local Qwen2.5-VL model.
     """