yejunliang23 commited on
Commit
6af08b9
·
unverified ·
1 Parent(s): a19e6cc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +149 -130
app.py CHANGED
@@ -4,27 +4,119 @@ from threading import Thread
4
  import gradio as gr
5
  from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor,TextIteratorStreamer,AutoTokenizer
6
  from qwen_vl_utils import process_vision_info
7
-
8
- # 3D mesh dependencies
9
  import trimesh
10
  from trimesh.exchange.gltf import export_glb
11
  import numpy as np
12
  import tempfile
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
- DESCRIPTION = '''
16
- <div>
17
- <h1 style="text-align: center;">LLaMA-Mesh</h1>
18
- <div>
19
- <a style="display:inline-block" href="https://research.nvidia.com/labs/toronto-ai/LLaMA-Mesh/"><img src='https://img.shields.io/badge/public_website-8A2BE2'></a>
20
- <a style="display:inline-block; margin-left: .5em" href="https://github.com/nv-tlabs/LLaMA-Mesh"><img src='https://img.shields.io/github/stars/nv-tlabs/LLaMA-Mesh?style=social'/></a>
21
- </div>
22
- <p>LLaMA-Mesh: Unifying 3D Mesh Generation with Language Models.<a style="display:inline-block" href="https://research.nvidia.com/labs/toronto-ai/LLaMA-Mesh/">[Project Page]</a> <a style="display:inline-block" href="https://github.com/nv-tlabs/LLaMA-Mesh">[Code]</a></p>
23
- <p> Notice: (1) This demo supports up to 4096 tokens due to computational limits, while our full model supports 8k tokens. This limitation may result in incomplete generated meshes. To experience the full 8k token context, please run our model locally.</p>
24
- <p>(2) We only support generating a single mesh per dialog round. To generate another mesh, click the "clear" button and start a new dialog.</p>
25
- <p>(3) If the LLM refuses to generate a 3D mesh, try adding more explicit instructions to the prompt, such as "create a 3D model of a table <strong>in OBJ format</strong>." A more effective approach is to request the mesh generation at the start of the dialog.</p>
26
- </div>
27
- '''
28
  # --------- Configuration & Model Loading ---------
29
  MODEL_DIR = "Qwen/Qwen2.5-VL-3B-Instruct"
30
  # Load processor, tokenizer, model for Qwen2.5-VL
@@ -35,90 +127,43 @@ model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
35
  trust_remote_code=True
36
  )
37
  processor = AutoProcessor.from_pretrained(MODEL_DIR)
38
- terminators = [tokenizer.eos_token_id]
 
39
 
40
- def chat_qwen_vl(message: str, history: list, temperature: float = 0.1, max_new_tokens: int = 1024):
41
- # —— 原有多模态输入构造 —— #
42
  messages = [
43
  {
44
  "role": "user",
45
  "content": [
46
- {"type": "text", "text": message},
47
  ],
48
  }
49
  ]
 
50
  text = processor.apply_chat_template(
51
- messages, tokenize=False, add_generation_prompt=True
52
- )
53
- print(text)
54
  image_inputs, video_inputs = process_vision_info(messages)
55
- input_ids = processor(
56
- text=[text],
57
- images=image_inputs,
58
- videos=video_inputs,
59
- padding=False,
60
- return_tensors="pt"
61
- ).to(model.device)
62
-
63
- # 2. 把 streamer 和生成参数一起传给 model.generate
64
  streamer = TextIteratorStreamer(
65
- processor,
66
- timeout=100.0,
67
- skip_prompt=True,
68
- skip_special_tokens=True
69
- )
70
- #print(input_ids)
71
- generate_kwargs = dict(
72
- input_ids= input_ids["input_ids"],
73
- streamer=streamer,
74
- max_new_tokens=max_new_tokens,
75
- do_sample=True,
76
- temperature=temperature,
77
- eos_token_id=terminators,
78
- top_k=1024,
79
- top_p=0.1
80
- )
81
 
82
- # 4. 后台线程启动生成
83
- Thread(target=model.generate, kwargs=generate_kwargs).start()
 
 
 
 
 
84
 
85
- # 5. 主线程读取 streamer 并 yield
86
  buffer = []
87
  for chunk in streamer:
88
- print(chunk)
89
  buffer.append(chunk)
90
- # 每次新到一个片段,就拼接并返回给前端
91
  yield "".join(buffer)
92
-
93
-
94
- # --------- 3D Mesh Coloring Function ---------
95
- def apply_gradient_color(mesh_text: str) -> str:
96
- """
97
- Apply a Y-axis-based gradient RGBA color to OBJ mesh text and export as GLB.
98
- """
99
- # Write OBJ to temp file
100
- tmp = tempfile.NamedTemporaryFile(suffix=".obj", delete=False)
101
- tmp.write(mesh_text.encode('utf-8'))
102
- tmp.flush()
103
- tmp.close()
104
-
105
- mesh = trimesh.load_mesh(tmp.name, file_type='obj')
106
- vertices = mesh.vertices
107
- ys = vertices[:, 1]
108
- y_norm = (ys - ys.min()) / (ys.max() - ys.min())
109
-
110
- colors = np.zeros((len(vertices), 4))
111
- colors[:, 0] = y_norm
112
- colors[:, 2] = 1 - y_norm
113
- colors[:, 3] = 1.0
114
- mesh.visual.vertex_colors = colors
115
-
116
- glb_path = tmp.name.replace('.obj', '.glb')
117
- with open(glb_path, 'wb') as f:
118
- f.write(export_glb(mesh))
119
- return glb_path
120
-
121
- # --------- Gradio Interface ---------
122
  css = """
123
  h1 { text-align: center; }
124
  """
@@ -128,53 +173,27 @@ PLACEHOLDER = (
128
  "<p style='font-size:18px;opacity:0.65;'>Ask anything or generate images!</p></div>"
129
  )
130
 
131
- chatbot=gr.Chatbot(height=450, placeholder=PLACEHOLDER, label='Gradio ChatInterface')
132
-
133
- with gr.Blocks(fill_height=True, css=css) as demo:
134
- with gr.Column():
135
- gr.Markdown(DESCRIPTION)
136
- # gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
137
- with gr.Row():
138
- with gr.Column(scale=3):
139
- gr.ChatInterface(
140
- fn=chat_qwen_vl,
141
- chatbot=chatbot,
142
- fill_height=True,
143
- additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
144
- additional_inputs=[
145
- gr.Slider(minimum=0,
146
- maximum=1,
147
- step=0.1,
148
- value=0.9,
149
- label="Temperature",
150
- interactive = False,
151
- render=False),
152
- gr.Slider(minimum=128,
153
- maximum=4096,
154
- step=1,
155
- value=4096,
156
- label="Max new tokens",
157
- interactive = False,
158
- render=False),
159
- ],
160
- examples=[
161
- ['Create a 3D model of a wooden hammer'],
162
- ['Create a 3D model of a pyramid in obj format'],
163
- ['Create a 3D model of a cabinet.'],
164
- ['Create a low poly 3D model of a coffe cup'],
165
- ['Create a 3D model of a table.'],
166
- ["Create a low poly 3D model of a tree."],
167
- ['Write a python code for sorting.'],
168
- ['How to setup a human base on Mars? Give short answer.'],
169
- ['Explain theory of relativity to me like I’m 8 years old.'],
170
- ['What is 9,000 * 9,000?'],
171
- ['Create a 3D model of a soda can.'],
172
- ['Create a 3D model of a sword.'],
173
- ['Create a 3D model of a wooden barrel'],
174
- ['Create a 3D model of a chair.']
175
- ],
176
- cache_examples=False,
177
- )
178
 
179
  if __name__ == "__main__":
180
- demo.launch()
 
4
  import gradio as gr
5
  from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor,TextIteratorStreamer,AutoTokenizer
6
  from qwen_vl_utils import process_vision_info
 
 
7
  import trimesh
8
  from trimesh.exchange.gltf import export_glb
9
  import numpy as np
10
  import tempfile
11
 
12
+ def predict(_chatbot, task_history):
13
+ chat_query = _chatbot[-1][0]
14
+ query = task_history[-1][0]
15
+ if len(chat_query) == 0:
16
+ _chatbot.pop()
17
+ task_history.pop()
18
+ return _chatbot
19
+ print("User: " + _parse_text(query))
20
+ history_cp = copy.deepcopy(task_history)
21
+ full_response = ""
22
+ messages = []
23
+ content = []
24
+ for q, a in history_cp:
25
+ if isinstance(q, (tuple, list)):
26
+ if is_video_file(q[0]):
27
+ content.append({'video': f'file://{q[0]}'})
28
+ else:
29
+ content.append({'image': f'file://{q[0]}'})
30
+ else:
31
+ content.append({'text': q})
32
+ messages.append({'role': 'user', 'content': content})
33
+ messages.append({'role': 'assistant', 'content': [{'text': a}]})
34
+ content = []
35
+ messages.pop()
36
+
37
+ messages = _transform_messages(messages)
38
+ text = processor.apply_chat_template(
39
+ messages, tokenize=False, add_generation_prompt=True)
40
+ image_inputs, video_inputs = process_vision_info(messages)
41
+ inputs = processor(text=[text], images=image_inputs,
42
+ videos=video_inputs, padding=True, return_tensors='pt')
43
+ inputs = inputs.to(model.device)
44
+
45
+ streamer = TextIteratorStreamer(
46
+ tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
47
+
48
+ gen_kwargs = {'max_new_tokens': 512, 'streamer': streamer, **inputs}
49
+
50
+ thread = Thread(target=model.generate, kwargs=gen_kwargs)
51
+ thread.start()
52
+
53
+ #for new_text in streamer:
54
+ # yield new_text
55
+
56
+ buffer = []
57
+ for chunk in streamer:
58
+ buffer.append(chunk)
59
+ yield "".join(buffer)
60
+
61
+
62
+ def regenerate(_chatbot, task_history):
63
+ if not task_history:
64
+ return _chatbot
65
+ item = task_history[-1]
66
+ if item[1] is None:
67
+ return _chatbot
68
+ task_history[-1] = (item[0], None)
69
+ chatbot_item = _chatbot.pop(-1)
70
+ if chatbot_item[0] is None:
71
+ _chatbot[-1] = (_chatbot[-1][0], None)
72
+ else:
73
+ _chatbot.append((chatbot_item[0], None))
74
+ _chatbot_gen = predict(_chatbot, task_history)
75
+ for _chatbot in _chatbot_gen:
76
+ yield _chatbot
77
+
78
+ def add_text(history, task_history, text):
79
+ task_text = text
80
+ history = history if history is not None else []
81
+ task_history = task_history if task_history is not None else []
82
+ history = history + [(_parse_text(text), None)]
83
+ task_history = task_history + [(task_text, None)]
84
+ return history, task_history, ""
85
+
86
+ def add_file(history, task_history, file):
87
+ history = history if history is not None else []
88
+ task_history = task_history if task_history is not None else []
89
+ history = history + [((file.name,), None)]
90
+ task_history = task_history + [((file.name,), None)]
91
+ return history, task_history
92
+
93
+ def reset_user_input():
94
+ return gr.update(value="")
95
+
96
+ def reset_state(task_history):
97
+ task_history.clear()
98
+ return []
99
+
100
+ def _transform_messages(original_messages):
101
+ transformed_messages = []
102
+ for message in original_messages:
103
+ new_content = []
104
+ for item in message['content']:
105
+ if 'image' in item:
106
+ new_item = {'type': 'image', 'image': item['image']}
107
+ elif 'text' in item:
108
+ new_item = {'type': 'text', 'text': item['text']}
109
+ elif 'video' in item:
110
+ new_item = {'type': 'video', 'video': item['video']}
111
+ else:
112
+ continue
113
+ new_content.append(new_item)
114
+
115
+ new_message = {'role': message['role'], 'content': new_content}
116
+ transformed_messages.append(new_message)
117
+
118
+ return transformed_messages
119
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
  # --------- Configuration & Model Loading ---------
121
  MODEL_DIR = "Qwen/Qwen2.5-VL-3B-Instruct"
122
  # Load processor, tokenizer, model for Qwen2.5-VL
 
127
  trust_remote_code=True
128
  )
129
  processor = AutoProcessor.from_pretrained(MODEL_DIR)
130
+ tokenizer = processor.tokenizer
131
+ #terminators = [tokenizer.eos_token_id]
132
 
133
+ def chat_qwen_vl(messages: str, history: list, temperature: float = 0.1, max_new_tokens: int = 1024):
 
134
  messages = [
135
  {
136
  "role": "user",
137
  "content": [
138
+ {"type": "text", "text": messages},
139
  ],
140
  }
141
  ]
142
+ messages = _transform_messages(messages)
143
  text = processor.apply_chat_template(
144
+ messages, tokenize=False, add_generation_prompt=True)
 
 
145
  image_inputs, video_inputs = process_vision_info(messages)
146
+ inputs = processor(text=[text], images=image_inputs,
147
+ videos=video_inputs, padding=True, return_tensors='pt')
148
+ inputs = inputs.to(model.device)
149
+
 
 
 
 
 
150
  streamer = TextIteratorStreamer(
151
+ tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
152
 
153
+ gen_kwargs = {'max_new_tokens': 512, 'streamer': streamer, **inputs}
154
+
155
+ thread = Thread(target=model.generate, kwargs=gen_kwargs)
156
+ thread.start()
157
+
158
+ #for new_text in streamer:
159
+ # yield new_text
160
 
 
161
  buffer = []
162
  for chunk in streamer:
 
163
  buffer.append(chunk)
 
164
  yield "".join(buffer)
165
+
166
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
167
  css = """
168
  h1 { text-align: center; }
169
  """
 
173
  "<p style='font-size:18px;opacity:0.65;'>Ask anything or generate images!</p></div>"
174
  )
175
 
176
+ with gr.Blocks() as demo:
177
+ gr.Markdown("""<center><font size=3> ShapeLLM-7B Demo </center>""")
178
+
179
+ chatbot = gr.Chatbot(label='ShapeLLM-4o', elem_classes="control-height", height=500)
180
+ query = gr.Textbox(lines=2, label='Input')
181
+ task_history = gr.State([])
182
+
183
+ with gr.Row():
184
+ addfile_btn = gr.UploadButton("📁 Upload (上传文件)", file_types=["image", "video"])
185
+ submit_btn = gr.Button("🚀 Submit (发送)")
186
+ regen_btn = gr.Button("🤔️ Regenerate (重试)")
187
+ empty_bin = gr.Button("🧹 Clear History (清除历史)")
188
+
189
+ submit_btn.click(add_text, [chatbot, task_history, query], [chatbot, task_history]).then(
190
+ predict, [chatbot, task_history], [chatbot], show_progress=True
191
+ )
192
+ submit_btn.click(reset_user_input, [], [query])
193
+ empty_bin.click(reset_state, [task_history], [chatbot], show_progress=True)
194
+ regen_btn.click(regenerate, [chatbot, task_history], [chatbot], show_progress=True)
195
+ addfile_btn.upload(add_file, [chatbot, task_history, addfile_btn], [chatbot, task_history], show_progress=True)
196
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
197
 
198
  if __name__ == "__main__":
199
+ demo.launch()