yejunliang23 commited on
Commit
aa6782f
·
unverified ·
1 Parent(s): f77d097

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +142 -226
app.py CHANGED
@@ -1,232 +1,148 @@
1
- # Copyright (c) Alibaba Cloud.
2
- #
3
- # This source code is licensed under the license found in the
4
- # LICENSE file in the root directory of this source tree.
5
  import os
6
- import numpy as np
7
- from urllib3.exceptions import HTTPError
8
- os.system('pip install dashscope modelscope oss2 -U')
9
-
10
- from argparse import ArgumentParser
11
- from pathlib import Path
12
-
13
- import copy
14
  import gradio as gr
15
- import oss2
16
- import os
17
- import re
18
- import secrets
 
 
19
  import tempfile
20
- import requests
21
- from http import HTTPStatus
22
- from dashscope import MultiModalConversation
23
- import dashscope
24
-
25
- API_KEY = os.environ['API_KEY']
26
- dashscope.api_key = API_KEY
27
-
28
- REVISION = 'v1.0.4'
29
- BOX_TAG_PATTERN = r"<box>([\s\S]*?)</box>"
30
- PUNCTUATION = "!?。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏."
31
-
32
-
33
- def _get_args():
34
- parser = ArgumentParser()
35
- parser.add_argument("--revision", type=str, default=REVISION)
36
- parser.add_argument("--cpu-only", action="store_true", help="Run demo with CPU only")
37
-
38
- parser.add_argument("--share", action="store_true", default=False,
39
- help="Create a publicly shareable link for the interface.")
40
- parser.add_argument("--inbrowser", action="store_true", default=False,
41
- help="Automatically launch the interface in a new tab on the default browser.")
42
- parser.add_argument("--server-port", type=int, default=7860,
43
- help="Demo server port.")
44
- parser.add_argument("--server-name", type=str, default="127.0.0.1",
45
- help="Demo server name.")
46
-
47
- args = parser.parse_args()
48
- return args
49
-
50
- def _parse_text(text):
51
- lines = text.split("\n")
52
- lines = [line for line in lines if line != ""]
53
- count = 0
54
- for i, line in enumerate(lines):
55
- if "```" in line:
56
- count += 1
57
- items = line.split("`")
58
- if count % 2 == 1:
59
- lines[i] = f'<pre><code class="language-{items[-1]}">'
60
- else:
61
- lines[i] = f"<br></code></pre>"
62
- else:
63
- if i > 0:
64
- if count % 2 == 1:
65
- line = line.replace("`", r"\`")
66
- line = line.replace("<", "&lt;")
67
- line = line.replace(">", "&gt;")
68
- line = line.replace(" ", "&nbsp;")
69
- line = line.replace("*", "&ast;")
70
- line = line.replace("_", "&lowbar;")
71
- line = line.replace("-", "&#45;")
72
- line = line.replace(".", "&#46;")
73
- line = line.replace("!", "&#33;")
74
- line = line.replace("(", "&#40;")
75
- line = line.replace(")", "&#41;")
76
- line = line.replace("$", "&#36;")
77
- lines[i] = "<br>" + line
78
- text = "".join(lines)
79
- return text
80
-
81
-
82
- def _remove_image_special(text):
83
- text = text.replace('<ref>', '').replace('</ref>', '')
84
- return re.sub(r'<box>.*?(</box>|$)', '', text)
85
-
86
-
87
- def is_video_file(filename):
88
- video_extensions = ['.mp4', '.avi', '.mkv', '.mov', '.wmv', '.flv', '.webm', '.mpeg']
89
- return any(filename.lower().endswith(ext) for ext in video_extensions)
90
-
91
-
92
- def _launch_demo(args):
93
- uploaded_file_dir = os.environ.get("GRADIO_TEMP_DIR") or str(
94
- Path(tempfile.gettempdir()) / "gradio"
95
- )
96
 
97
- def predict(_chatbot, task_history):
98
- chat_query = _chatbot[-1][0]
99
- query = task_history[-1][0]
100
- if len(chat_query) == 0:
101
- _chatbot.pop()
102
- task_history.pop()
103
- return _chatbot
104
- print("User: " + _parse_text(query))
105
- history_cp = copy.deepcopy(task_history)
106
- full_response = ""
107
- messages = []
108
- content = []
109
- for q, a in history_cp:
110
- if isinstance(q, (tuple, list)):
111
- if is_video_file(q[0]):
112
- content.append({'video': f'file://{q[0]}'})
113
- else:
114
- content.append({'image': f'file://{q[0]}'})
115
- else:
116
- content.append({'text': q})
117
- messages.append({'role': 'user', 'content': content})
118
- messages.append({'role': 'assistant', 'content': [{'text': a}]})
119
- content = []
120
- messages.pop()
121
- responses = MultiModalConversation.call(
122
- model='qwen2.5-vl-32b-instruct', messages=messages, stream=True,
123
- )
124
- for response in responses:
125
- if not response.status_code == HTTPStatus.OK:
126
- raise HTTPError(f'response.code: {response.code}\nresponse.message: {response.message}')
127
- response = response.output.choices[0].message.content
128
- response_text = []
129
- for ele in response:
130
- if 'text' in ele:
131
- response_text.append(ele['text'])
132
- elif 'box' in ele:
133
- response_text.append(ele['box'])
134
- response_text = ''.join(response_text)
135
- _chatbot[-1] = (_parse_text(chat_query), _remove_image_special(response_text))
136
- yield _chatbot
137
-
138
- if len(response) > 1:
139
- result_image = response[-1]['result_image']
140
- resp = requests.get(result_image)
141
- os.makedirs(uploaded_file_dir, exist_ok=True)
142
- name = f"tmp{secrets.token_hex(20)}.jpg"
143
- filename = os.path.join(uploaded_file_dir, name)
144
- with open(filename, 'wb') as f:
145
- f.write(resp.content)
146
- response = ''.join(r['box'] if 'box' in r else r['text'] for r in response[:-1])
147
- _chatbot.append((None, (filename,)))
148
- else:
149
- response = response[0]['text']
150
- _chatbot[-1] = (_parse_text(chat_query), response)
151
- full_response = _parse_text(response)
152
-
153
- task_history[-1] = (query, full_response)
154
- print("Qwen2.5-VL-Chat: " + _parse_text(full_response))
155
- yield _chatbot
156
-
157
-
158
- def regenerate(_chatbot, task_history):
159
- if not task_history:
160
- return _chatbot
161
- item = task_history[-1]
162
- if item[1] is None:
163
- return _chatbot
164
- task_history[-1] = (item[0], None)
165
- chatbot_item = _chatbot.pop(-1)
166
- if chatbot_item[0] is None:
167
- _chatbot[-1] = (_chatbot[-1][0], None)
168
- else:
169
- _chatbot.append((chatbot_item[0], None))
170
- _chatbot_gen = predict(_chatbot, task_history)
171
- for _chatbot in _chatbot_gen:
172
- yield _chatbot
173
-
174
- def add_text(history, task_history, text):
175
- task_text = text
176
- history = history if history is not None else []
177
- task_history = task_history if task_history is not None else []
178
- history = history + [(_parse_text(text), None)]
179
- task_history = task_history + [(task_text, None)]
180
- return history, task_history, ""
181
-
182
- def add_file(history, task_history, file):
183
- history = history if history is not None else []
184
- task_history = task_history if task_history is not None else []
185
- history = history + [((file.name,), None)]
186
- task_history = task_history + [((file.name,), None)]
187
- return history, task_history
188
-
189
- def reset_user_input():
190
- return gr.update(value="")
191
-
192
- def reset_state(task_history):
193
- task_history.clear()
194
- return []
195
-
196
- with gr.Blocks() as demo:
197
- gr.Markdown("""<center><font size=3> Qwen2.5-VL-32B-Instruct Demo </center>""")
198
-
199
- chatbot = gr.Chatbot(label='Qwen2.5-VL-32B-Instruct', elem_classes="control-height", height=500)
200
- query = gr.Textbox(lines=2, label='Input')
201
- task_history = gr.State([])
202
-
203
- with gr.Row():
204
- addfile_btn = gr.UploadButton("📁 Upload (上传文件)", file_types=["image", "video"])
205
- submit_btn = gr.Button("🚀 Submit (发送)")
206
- regen_btn = gr.Button("🤔️ Regenerate (重试)")
207
- empty_bin = gr.Button("🧹 Clear History (清除历史)")
208
-
209
- submit_btn.click(add_text, [chatbot, task_history, query], [chatbot, task_history]).then(
210
- predict, [chatbot, task_history], [chatbot], show_progress=True
211
- )
212
- submit_btn.click(reset_user_input, [], [query])
213
- empty_bin.click(reset_state, [task_history], [chatbot], show_progress=True)
214
- regen_btn.click(regenerate, [chatbot, task_history], [chatbot], show_progress=True)
215
- addfile_btn.upload(add_file, [chatbot, task_history, addfile_btn], [chatbot, task_history], show_progress=True)
216
-
217
-
218
- demo.queue(default_concurrency_limit=40).launch(
219
- share=args.share,
220
- # inbrowser=args.inbrowser,
221
- # server_port=args.server_port,
222
- # server_name=args.server_name,
223
  )
224
 
225
-
226
- def main():
227
- args = _get_args()
228
- _launch_demo(args)
229
-
230
-
231
- if __name__ == '__main__':
232
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
+ import torch
3
+ from threading import Thread
 
 
 
 
 
 
4
  import gradio as gr
5
+ from transformers import AutoProcessor, AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
6
+
7
+ # 3D mesh dependencies
8
+ import trimesh
9
+ from trimesh.exchange.gltf import export_glb
10
+ import numpy as np
11
  import tempfile
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
+ # --------- Configuration & Model Loading ---------
14
+ MODEL_DIR = "./qwen2.5-vl-32b-instruct"
15
+ # Load processor, tokenizer, model for Qwen2.5-VL
16
+ processor = AutoProcessor.from_pretrained(MODEL_DIR)
17
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, trust_remote_code=True)
18
+ model = AutoModelForCausalLM.from_pretrained(
19
+ MODEL_DIR,
20
+ torch_dtype=torch.float16,
21
+ device_map="auto",
22
+ trust_remote_code=True
23
+ )
24
+
25
+ # Terminator tokens
26
+ terminators = [tokenizer.eos_token_id]
27
+
28
+ # --------- Chat Inference Function ---------
29
+ def chat_qwen_vl(message: str, history: list, temperature: float = 0.7, max_new_tokens: int = 1024):
30
+ """
31
+ Stream chat response from local Qwen2.5-VL model.
32
+ """
33
+ # Build conversation prompt
34
+ conv = []
35
+ for u, a in history:
36
+ conv.append(f"<user> {u}")
37
+ conv.append(f"<assistant> {a}")
38
+ conv.append(f"<user> {message}")
39
+ conv.append("<assistant>")
40
+
41
+ # Tokenize
42
+ inputs = tokenizer(
43
+ "\n".join(conv),
44
+ return_tensors="pt",
45
+ truncation=True,
46
+ max_length=4096
47
+ ).to(model.device)
48
+
49
+ # Create streamer
50
+ streamer = TextIteratorStreamer(
51
+ tokenizer,
52
+ timeout=10.0,
53
+ skip_prompt=True,
54
+ skip_special_tokens=True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  )
56
 
57
+ # Generation kwargs
58
+ gen_kwargs = dict(
59
+ input_ids=inputs.input_ids,
60
+ attention_mask=inputs.attention_mask,
61
+ streamer=streamer,
62
+ do_sample=(temperature > 0),
63
+ temperature=temperature,
64
+ max_new_tokens=max_new_tokens,
65
+ eos_token_id=terminators,
66
+ )
67
+ if temperature == 0:
68
+ gen_kwargs["do_sample"] = False
69
+
70
+ # Launch generation in thread
71
+ thread = Thread(target=model.generate, kwargs=gen_kwargs)
72
+ thread.start()
73
+
74
+ # Stream outputs
75
+ output_chunks = []
76
+ for chunk in streamer:
77
+ output_chunks.append(chunk)
78
+ yield "".join(output_chunks)
79
+
80
+ # --------- 3D Mesh Coloring Function ---------
81
+ def apply_gradient_color(mesh_text: str) -> str:
82
+ """
83
+ Apply a Y-axis-based gradient RGBA color to OBJ mesh text and export as GLB.
84
+ """
85
+ # Write OBJ to temp file
86
+ tmp = tempfile.NamedTemporaryFile(suffix=".obj", delete=False)
87
+ tmp.write(mesh_text.encode('utf-8'))
88
+ tmp.flush()
89
+ tmp.close()
90
+
91
+ mesh = trimesh.load_mesh(tmp.name, file_type='obj')
92
+ vertices = mesh.vertices
93
+ ys = vertices[:, 1]
94
+ y_norm = (ys - ys.min()) / (ys.max() - ys.min())
95
+
96
+ colors = np.zeros((len(vertices), 4))
97
+ colors[:, 0] = y_norm
98
+ colors[:, 2] = 1 - y_norm
99
+ colors[:, 3] = 1.0
100
+ mesh.visual.vertex_colors = colors
101
+
102
+ glb_path = tmp.name.replace('.obj', '.glb')
103
+ with open(glb_path, 'wb') as f:
104
+ f.write(export_glb(mesh))
105
+ return glb_path
106
+
107
+ # --------- Gradio Interface ---------
108
+ css = """
109
+ h1 { text-align: center; }
110
+ """
111
+ PLACEHOLDER = (
112
+ "<div style='padding:30px;text-align:center;display:flex;flex-direction:column;align-items:center;'>"
113
+ "<h1 style='font-size:28px;opacity:0.55;'>Qwen2.5-VL Local Chat</h1>"
114
+ "<p style='font-size:18px;opacity:0.65;'>Ask anything or generate images!</p></div>"
115
+ )
116
+
117
+ demo = gr.Blocks(css=css)
118
+ with demo:
119
+ gr.Markdown("## Local Qwen2.5-VL + 3D Mesh Visualization")
120
+ with gr.Row():
121
+ with gr.Column(scale=3):
122
+ chatbot = gr.Chatbot(label="Qwen2.5-VL-Local", height=450, placeholder=PLACEHOLDER)
123
+ gr.ChatInterface(
124
+ fn=chat_qwen_vl,
125
+ chatbot=chatbot,
126
+ temperature=0.7,
127
+ max_new_tokens=1024,
128
+ examples=[
129
+ ["Hello, what is Qwen2.5-VL?"],
130
+ ["Generate a simple Python function to sort a list."],
131
+ ]
132
+ )
133
+ with gr.Column(scale=2):
134
+ output_model = gr.Model3D(label="3D Mesh Visualization", interactive=False)
135
+ mesh_input = gr.Textbox(
136
+ label="3D Mesh OBJ Input",
137
+ placeholder="Paste your 3D mesh in OBJ format here...",
138
+ lines=5
139
+ )
140
+ visualize_btn = gr.Button("Visualize with Gradient Color")
141
+ visualize_btn.click(
142
+ fn=apply_gradient_color,
143
+ inputs=[mesh_input],
144
+ outputs=[output_model]
145
+ )
146
+
147
+ if __name__ == "__main__":
148
+ demo.launch()