zR committed on
Commit
1e0442b
·
1 Parent(s): 7b5683e
Files changed (3) hide show
  1. README.md +7 -1
  2. app.py +229 -0
  3. requirements.txt +19 -0
README.md CHANGED
@@ -10,4 +10,10 @@ pinned: false
10
  short_description: CogAgent1.5-Chat-Demo
11
  ---
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
10
  short_description: CogAgent1.5-Chat-Demo
11
  ---
12
 
13
+ ## Running the Model
14
+
15
+ 1. Install the required libraries
16
+
17
+ ```bash
18
+ pip install -r requirements.txt
19
+ ```
app.py ADDED
@@ -0,0 +1,229 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import re
4
+ import threading
5
+ import time
6
+ from datetime import datetime, timedelta
7
+
8
+ import torch
9
+ from threading import Thread, Event
10
+ from PIL import Image, ImageDraw
11
+ import gradio as gr
12
+ from transformers import (
13
+ AutoTokenizer,
14
+ AutoModelForCausalLM,
15
+ TextIteratorStreamer,
16
+ )
17
+ from typing import List
18
+ import spaces
19
+
20
# Shared flag: set by stop_now(), cleared and polled by predict() to abort the current streaming round.
+ stop_event = Event()
21
+
22
def delete_old_files(
    directories=("./output", "./gradio_tmp"),
    max_age_minutes=10,
    interval_seconds=600,
    iterations=None,
):
    """Periodically delete stale regular files from a set of directories.

    Each pass removes every regular file (sub-directories are left alone)
    whose modification time is older than ``max_age_minutes``. Directories
    that do not exist, and files that vanish or cannot be removed, are
    skipped so that a single filesystem error cannot kill the cleanup thread.

    Args:
        directories: Iterable of directory paths to scan (not recursive).
        max_age_minutes: Files older than this many minutes are deleted.
        interval_seconds: Pause between two consecutive passes.
        iterations: Number of passes to run; ``None`` (default) loops forever,
            which is the behaviour expected by the background daemon thread.

    NOTE(review): the demo's default ``--output_dir`` is "outputs", which is
    not part of the default list here - confirm whether annotated images
    should be purged as well.
    """
    completed = 0
    while iterations is None or completed < iterations:
        cutoff = datetime.now() - timedelta(minutes=max_age_minutes)

        for directory in directories:
            try:
                filenames = os.listdir(directory)
            except OSError:
                # Directory missing or unreadable: nothing to clean this pass.
                continue

            for filename in filenames:
                file_path = os.path.join(directory, filename)
                try:
                    if not os.path.isfile(file_path):
                        continue
                    file_mtime = datetime.fromtimestamp(os.path.getmtime(file_path))
                    if file_mtime < cutoff:
                        os.remove(file_path)
                except OSError:
                    # File disappeared or is locked; try again on the next pass.
                    continue

        completed += 1
        if iterations is not None and completed >= iterations:
            # Do not sleep after the final requested pass.
            break
        time.sleep(interval_seconds)
36
+
37
+
38
# Module-level side effect: the cleanup loop starts in a daemon thread as soon as this file is executed.
+ threading.Thread(target=delete_old_files, daemon=True).start()
39
+
40
+
41
def draw_boxes_on_image(image: Image.Image, boxes: List[List[float]], save_path: str):
    """Draw red rectangles on ``image`` (in place) and save it to ``save_path``.

    Args:
        image: Image to annotate; it is modified in place.
        boxes: Boxes as ``[x0, y0, x1, y1]`` fractions of the image width and
            height (each value is multiplied by the image size).
        save_path: Destination path handed to ``image.save``.
    """
    draw = ImageDraw.Draw(image)
    for box in boxes:
        x_a, x_b = int(box[0] * image.width), int(box[2] * image.width)
        y_a, y_b = int(box[1] * image.height), int(box[3] * image.height)
        # Recent Pillow versions raise ValueError when the second corner
        # precedes the first, so order the corners before drawing.
        corners = [min(x_a, x_b), min(y_a, y_b), max(x_a, x_b), max(y_a, y_b)]
        draw.rectangle(corners, outline="red", width=3)
    image.save(save_path)
50
+
51
+
52
def preprocess_messages(history, img_path, platform_str, format_str):
    """Build the text query and load the screenshot for one model call.

    Args:
        history: Chat rounds as ``[task, model_message]`` pairs; may be empty
            or ``None``.
        img_path: Path of the screenshot to open.
        platform_str: Platform line appended to the query.
        format_str: Answer-format hint appended to the query.

    Returns:
        ``(query, image)`` where ``query`` is the prompt string and ``image``
        is the screenshot converted to RGB.
    """
    history = history or []
    grounded_pattern = re.compile(r"Grounded Operation:\s*(.*)")

    # Collect the grounded operation of every earlier answer that has one.
    history_step = []
    for _, model_msg in history:
        if not isinstance(model_msg, str):
            # A round without a textual answer has nothing to extract.
            continue
        found = grounded_pattern.search(model_msg)
        if found:
            history_step.append(found.group(1))

    history_str = "\nHistory steps: " + "".join(
        f"\n{i}. {step}" for i, step in enumerate(history_step)
    )

    # The latest round carries the task currently being worked on.
    task = history[-1][0] if history else "No task provided"

    query = f"Task: {task}{history_str}\n{platform_str}{format_str}"
    # convert() returns a new, fully loaded image, so the source file can be
    # closed right away instead of lingering until garbage collection.
    with Image.open(img_path) as source:
        image = source.convert("RGB")
    return query, image
74
+
75
+
76
@spaces.GPU()
def predict(history, max_length, top_p, temperature, img_path, platform_str, format_str, output_dir):
    """Stream the model's answer for the latest round in ``history``.

    The last entry of ``history`` is expected to be the pending round
    ``[task, ""]`` appended by ``user``. Generated text is appended to it token
    by token. If the stop event is set while streaming, generation is aborted
    and the pending round is removed again.

    Args:
        history: Chat rounds as ``[task, model_message]`` pairs.
        max_length: Maximum number of new tokens to generate.
        top_p: Nucleus-sampling threshold (only used when sampling).
        temperature: Sampling temperature; ``0`` selects greedy decoding.
        img_path: Path of the uploaded screenshot.
        platform_str: Platform line for the prompt.
        format_str: Answer-format hint for the prompt.
        output_dir: Directory where annotated screenshots are written.

    Yields:
        ``(history, annotated_image_path_or_None)`` after every update.

    Raises:
        gr.Error: If no screenshot has been uploaded.
    """
    # Imported locally because the file-level transformers import does not
    # bring these two names into scope.
    from transformers import StoppingCriteria, StoppingCriteriaList

    class _StopOnEvent(StoppingCriteria):
        """Ask ``generate`` to finish as soon as the shared stop event is set."""

        def __call__(self, input_ids, scores, **kwargs):
            return torch.full(
                (input_ids.shape[0],),
                stop_event.is_set(),
                dtype=torch.bool,
                device=input_ids.device,
            )

    # Reset the stop_event at the start of prediction
    stop_event.clear()

    if not img_path:
        raise gr.Error("Please upload a screenshot before submitting a task.")

    query, image = preprocess_messages(history, img_path, platform_str, format_str)
    model_inputs = tokenizer.apply_chat_template(
        [{"role": "user", "image": image, "content": query}],
        add_generation_prompt=True,
        tokenize=True,
        return_tensors="pt",
        return_dict=True,
    ).to(model.device)

    streamer = TextIteratorStreamer(
        tokenizer, timeout=60, skip_prompt=True, skip_special_tokens=True
    )
    do_sample = temperature > 0.0
    # NOTE(review): only input_ids and attention_mask are forwarded to
    # generate(); confirm whether the chat template also returns image tensors
    # that the model needs.
    generate_kwargs = {
        "input_ids": model_inputs["input_ids"],
        "attention_mask": model_inputs["attention_mask"],
        "streamer": streamer,
        "max_new_tokens": int(max_length),
        "do_sample": do_sample,
        "stopping_criteria": StoppingCriteriaList([_StopOnEvent()]),
    }
    if do_sample:
        # Sampling knobs are meaningless (and warned about) for greedy decoding.
        generate_kwargs["top_p"] = top_p
        generate_kwargs["temperature"] = temperature

    worker = Thread(target=model.generate, kwargs=generate_kwargs)
    worker.start()

    for new_token in streamer:
        if stop_event.is_set():
            break
        if new_token:
            history[-1][1] += new_token
        yield history, None

    # The stopping criterion makes generate() return promptly after a stop,
    # so waiting here does not block for the rest of the token budget.
    worker.join()

    if stop_event.is_set():
        # Roll back the round that was being generated. It was appended by
        # user() before this function ran, so it is always the last entry.
        if history:
            history.pop()
        yield history, None
        return

    # Finished normally: draw any predicted boxes onto a copy of the screenshot.
    response = history[-1][1]
    box_pattern = r"box=\[\[?(\d+),(\d+),(\d+),(\d+)\]?\]"
    matches = re.findall(box_pattern, response)
    if matches:
        # Box coordinates are emitted on a 0-1000 scale; convert to fractions.
        boxes = [[int(x) / 1000 for x in match] for match in matches]
        os.makedirs(output_dir, exist_ok=True)
        base_name = os.path.splitext(os.path.basename(img_path))[0]
        round_num = sum(1 for (u, m) in history if u and m)
        output_path = os.path.join(output_dir, f"{base_name}_{round_num}.png")
        image = Image.open(img_path).convert("RGB")
        draw_boxes_on_image(image, boxes, output_path)
        yield history, output_path
    else:
        yield history, None
137
+
138
+
139
def user(task, history):
    """Append a new pending round for ``task`` and clear the input box.

    Args:
        task: Text entered by the user.
        history: Existing chat rounds; ``None`` is treated as an empty chat.

    Returns:
        ``("", new_history)`` where ``new_history`` is a new list ending with
        ``[task, ""]``; the list passed in is not modified.
    """
    rounds = list(history) if history else []
    rounds.append([task, ""])
    return "", rounds
141
+
142
+
143
def undo_last_round(history, output_img):
    """Remove the most recent chat round and clear the annotated image.

    The history list is modified in place and returned together with ``None``
    for the image output. An empty or ``None`` history is returned unchanged.
    """
    if not history:
        return history, None
    del history[-1]
    return history, None
147
+
148
+
149
def clear_all_history():
    """Return ``None`` for both the chatbot and the annotated-image outputs."""
    cleared_chat = None
    cleared_image = None
    return cleared_chat, cleared_image
151
+
152
+
153
def stop_now():
    """Signal the running prediction to stop.

    Sets the shared stop event and returns one ``gr.update()`` for each of
    the two outputs wired to the stop button.
    """
    stop_event.set()
    chat_update = gr.update()
    image_update = gr.update()
    return chat_update, image_update
156
+
157
+
158
def main():
    """Parse CLI options, load the model, then build and launch the Gradio UI.

    Raises:
        ValueError: If ``--format_key`` is not one of the known prompt formats.
    """
    parser = argparse.ArgumentParser(description="CogAgent Gradio Demo")
    parser.add_argument("--model_dir", default="THUDM/cogagent1.5-9b", help="Path or identifier of the model.")
    parser.add_argument("--format_key", default="action_op_sensitive", help="Key to select the prompt format.")
    parser.add_argument("--platform", default="Mac", help="Platform information string.")
    parser.add_argument("--output_dir", default="outputs", help="Directory to save annotated images.")
    args = parser.parse_args()

    format_dict = {
        "action_op_sensitive": "(Answer in Action-Operation-Sensitive format.)",
        "status_plan_action_op": "(Answer in Status-Plan-Action-Operation format.)",
        "status_action_op_sensitive": "(Answer in Status-Action-Operation-Sensitive format.)",
        "status_action_op": "(Answer in Status-Action-Operation format.)",
        "action_op": "(Answer in Action-Operation format.)",
    }

    if args.format_key not in format_dict:
        raise ValueError(f"Invalid format_key. Available keys: {list(format_dict.keys())}")

    # predict() reads these as module-level globals.
    global tokenizer, model
    tokenizer = AutoTokenizer.from_pretrained(args.model_dir, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        args.model_dir, torch_dtype=torch.bfloat16, trust_remote_code=True, device_map="auto"
    ).eval()

    platform_str = f"(Platform: {args.platform})\n"
    format_str = format_dict[args.format_key]

    with gr.Blocks(analytics_enabled=False) as demo:
        gr.HTML("<h1 align='center'>CogAgent1.5-9B Demo</h1>")
        gr.HTML(
            "<p align='center' style='color:red;'>This Demo is for learning and communication purposes only. Users must assume responsibility for the risks associated with AI-generated planning and operations.</p>")

        with gr.Row():
            img_path = gr.Image(label="Upload a Screenshot", type="filepath", height=400)
            output_img = gr.Image(type="filepath", label="Annotated Image", height=400, interactive=False)

        with gr.Row():
            with gr.Column(scale=2):
                chatbot = gr.Chatbot(height=300)
                task = gr.Textbox(show_label=True, placeholder="Input...", label="Task")
                submitBtn = gr.Button("Submit")
            with gr.Column(scale=1):
                # Minimum is 1: generate() rejects a zero new-token budget.
                max_length = gr.Slider(1, 8192, value=1024, step=1.0, label="Maximum length", interactive=True)
                top_p = gr.Slider(0, 1, value=0.0, step=0.01, label="Top P", interactive=True)
                # Range starts at 0 so the default (0.0 = greedy decoding) is a
                # valid slider position the user can return to.
                temperature = gr.Slider(0, 1, value=0.0, step=0.01, label="Temperature", interactive=True)
                undo_last_round_btn = gr.Button("Back to Last Round")
                clear_history_btn = gr.Button("Clear All History")

                # Red "stop immediately" button: interrupts generation and
                # rolls back the current round of history.
                stop_now_btn = gr.Button("Stop Now", variant="stop")

        # First record the user's task in the chat, then stream the answer.
        submitBtn.click(
            user, [task, chatbot], [task, chatbot], queue=False
        ).then(
            predict,
            [chatbot, max_length, top_p, temperature, img_path, gr.State(platform_str), gr.State(format_str),
             gr.State(args.output_dir)],
            [chatbot, output_img],
            queue=True
        )

        undo_last_round_btn.click(undo_last_round, [chatbot, output_img], [chatbot, output_img], queue=False)
        clear_history_btn.click(clear_all_history, None, [chatbot, output_img], queue=False)
        stop_now_btn.click(stop_now, None, [chatbot, output_img], queue=False)

    demo.queue()
    demo.launch()
226
+
227
+
228
+ if __name__ == "__main__":
229
+ main()
requirements.txt ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ transformers>=4.47.2
2
+ torch==2.5.0
3
+ torchvision==0.20.0
4
+ huggingface-hub>=0.25.1
5
+ sentencepiece>=0.2.0
6
+ jinja2>=3.1.4
7
+ pydantic>=2.9.2
8
+ timm>=1.0.9
9
+ tiktoken>=0.8.0
10
+ numpy==1.26.4
11
+ accelerate>=1.1.1
12
+ sentence_transformers>=3.1.1
13
+ gradio>=5.9.0
14
+ openai>=1.58.0
15
+ einops>=0.8.0
16
+ pillow>=10.4.0
17
+ sse-starlette>=2.1.3
18
+ bitsandbytes>=0.43.2
19
+ spaces>=0.31.1