YongdongWang committed (verified)
Commit 0af933c · 1 Parent(s): e70a7f9

Force update Space with optimized robot planning interface

Files changed (3):
  1. README.md +10 -25
  2. app.py +79 -87
  3. requirements.txt +2 -0
README.md CHANGED
@@ -8,38 +8,23 @@ sdk_version: 4.44.0
 app_file: app.py
 pinned: false
 license: llama3.1
+hardware: t4-medium
 ---
 
-# Robot Task Planning - Llama 3.1 8B
+# 🤖 Robot Task Planning - Llama 3.1 8B
 
-This Space demonstrates a fine-tuned version of Meta's Llama 3.1 8B model specialized for **robot task planning** using QLoRA technique.
-
-The model converts natural language commands into structured task sequences for construction robots like excavators and dump trucks.
+Fine-tuned Llama 3.1 8B model for robot task planning using QLoRA technique.
 
 ## Model
-
-The model is available at: [YongdongWang/llama-3.1-8b-dart-qlora](https://huggingface.co/YongdongWang/llama-3.1-8b-dart-qlora)
+[YongdongWang/llama-3.1-8b-dart-qlora](https://huggingface.co/YongdongWang/llama-3.1-8b-dart-qlora)
 
 ## Features
-
-- **Robot Command Processing**: Convert natural language to structured robot tasks
-- **Multi-Robot Coordination**: Handle complex scenarios with multiple excavators and dump trucks
-- **Task Dependencies**: Generate proper task sequences with dependencies
-- **Real-time Planning**: Instant task generation powered by Gradio
+- Natural language to robot task conversion
+- Multi-robot coordination
+- Real-time task generation
+- Optimized with 4-bit quantization
 
 ## Usage
-
-Input natural language robot commands like "Deploy Excavator 1 to Soil Area 1" and the model will generate structured task sequences in JSON format for robot execution.
-
-## Technical Details
-
-- **Base Model**: meta-llama/Llama-3.1-8B
-- **Fine-tuning**: QLoRA (4-bit quantization + LoRA)
-- **Interface**: Gradio
-- **Hosting**: HuggingFace Spaces
-- **Input**: Natural language robot commands
-- **Output**: Structured JSON task sequences
-
-## Performance
-
-⚠️ **Note**: Model loading may take 3-5 minutes on first startup due to the large model size and quantization process.
+Input robot commands and get structured task sequences for excavators, dump trucks, and other construction robots.
+
+Loading time: ~3-5 minutes on first startup.
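
The README promises structured task sequences in JSON, but no sample output appears in this commit. A hypothetical illustration of the intended round trip — the command comes from the Space's own examples, while the JSON field names (`tasks`, `action`, `depends_on`) are invented here for illustration and are not defined anywhere in this commit:

```python
import json

# Hypothetical output shape — the real schema comes from the fine-tuned model,
# not from this commit; the field names below are illustrative only.
raw_response = """
{
  "tasks": [
    {"id": 1, "robot": "Excavator 1", "action": "move_to",  "target": "Soil Area 1", "depends_on": []},
    {"id": 2, "robot": "Excavator 1", "action": "excavate", "target": "Soil Area 1", "depends_on": [1]}
  ]
}
"""

plan = json.loads(raw_response)
for task in plan["tasks"]:
    print(f'{task["robot"]}: {task["action"]} -> {task["target"]} (after {task["depends_on"]})')
```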
app.py CHANGED
@@ -3,6 +3,7 @@ import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
 from peft import PeftModel
 import warnings
+import os
 warnings.filterwarnings("ignore")
 
 # Model configuration
@@ -23,7 +24,11 @@ def load_model():
         )
 
         # Load the tokenizer
-        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False)
+        tokenizer = AutoTokenizer.from_pretrained(
+            MODEL_NAME,
+            use_fast=False,
+            trust_remote_code=True
+        )
         if tokenizer.pad_token is None:
             tokenizer.pad_token = tokenizer.eos_token
 
@@ -33,11 +38,16 @@ def load_model():
             quantization_config=bnb_config,
             device_map="auto",
             torch_dtype=torch.float16,
-            trust_remote_code=True
+            trust_remote_code=True,
+            low_cpu_mem_usage=True
         )
 
         # Load the LoRA adapter
-        model = PeftModel.from_pretrained(base_model, LORA_MODEL)
+        model = PeftModel.from_pretrained(
+            base_model,
+            LORA_MODEL,
+            torch_dtype=torch.float16
+        )
         model.eval()
 
         print("✅ Model loaded successfully!")
@@ -47,28 +57,47 @@ def load_model():
         print(f"❌ Model loading failed: {load_error}")
         return None, None
 
-# Global variables for the model
+# Global variables
 model = None
 tokenizer = None
+model_loading = False
 
 def initialize_model():
-    """Initialize the model - lazy loading"""
-    global model, tokenizer
-    if model is None or tokenizer is None:
+    """Initialize the model"""
+    global model, tokenizer, model_loading
+
+    if model is not None and tokenizer is not None:
+        return True
+
+    if model_loading:
+        return False
+
+    model_loading = True
+    try:
         model, tokenizer = load_model()
-    return model is not None and tokenizer is not None
+        return model is not None and tokenizer is not None
+    finally:
+        model_loading = False
 
 def generate_response(prompt, max_tokens=200, temperature=0.7, top_p=0.9):
     """Generate a response"""
     if not initialize_model():
-        return "❌ Model not loaded. Please check the logs or try again."
+        if model_loading:
+            return "🔄 Model is loading, please wait a few minutes and try again..."
+        else:
+            return "❌ Model failed to load. Please check the Space logs."
 
     try:
-        # Format the input - remove the redundant string interpolation
-        formatted_prompt = prompt.strip()
+        # Format the input
+        formatted_prompt = f"### Human: {prompt.strip()}\n### Assistant:"
 
         # Encode the input
-        inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
+        inputs = tokenizer(
+            formatted_prompt,
+            return_tensors="pt",
+            truncation=True,
+            max_length=2048
+        ).to(model.device)
 
         # Generate the response
         with torch.no_grad():
@@ -82,19 +111,18 @@ def generate_response(prompt, max_tokens=200, temperature=0.7, top_p=0.9):
                 eos_token_id=tokenizer.eos_token_id,
                 repetition_penalty=1.1,
                 early_stopping=True,
+                no_repeat_ngram_size=3
             )
 
         # Decode the output
         response = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
-        # Remove the original input, keeping only the generated part
-        if len(response) > len(formatted_prompt):
+        # Extract the generated part
+        if "### Assistant:" in response:
+            response = response.split("### Assistant:")[-1].strip()
+        elif len(response) > len(formatted_prompt):
             response = response[len(formatted_prompt):].strip()
 
-        # Clean up special markers if the response contains them
-        if "Assistant:" in response:
-            response = response.split("Assistant:")[-1].strip()
-
         return response if response else "❌ No response generated. Please try again."
 
     except Exception as generation_error:
@@ -115,104 +143,68 @@ def chat_interface(message, history, max_tokens, temperature, top_p):
     return history, ""
 
 # Build the Gradio app
-with gr.Blocks(title="Robot Task Planning - Llama 3.1 8B", theme=gr.themes.Soft()) as demo:
+with gr.Blocks(
+    title="Robot Task Planning - Llama 3.1 8B",
+    theme=gr.themes.Soft(),
+    css="""
+    .gradio-container {
+        max-width: 1200px;
+        margin: auto;
+    }
+    """
+) as demo:
     gr.Markdown("""
     # 🤖 Llama 3.1 8B - Robot Task Planning
 
-    This is a fine-tuned version of Meta's Llama 3.1 8B model specialized for **robot task planning** using QLoRA technique.
-
-    **Capabilities**: Convert natural language robot commands into structured task sequences for excavators, dump trucks, and other construction robots.
+    Fine-tuned version of Meta's Llama 3.1 8B for **robot task planning** using QLoRA.
 
     **Model**: [YongdongWang/llama-3.1-8b-dart-qlora](https://huggingface.co/YongdongWang/llama-3.1-8b-dart-qlora)
 
-    ⚠️ **Note**: Model loading may take a few minutes on first startup.
+    ⚠️ **First load takes 3-5 minutes**
     """)
 
     with gr.Row():
         with gr.Column(scale=3):
             chatbot = gr.Chatbot(
-                label="Task Planning Results",
-                height=400,
-                show_label=True,
-                container=True,
-                bubble_full_width=False
+                label="🤖 Task Planning Results",
+                height=500,
+                show_copy_button=True
             )
 
             msg = gr.Textbox(
                 label="Robot Command",
-                placeholder="Enter robot task command (e.g., 'Deploy Excavator 1 to Soil Area 1')...",
-                lines=2,
-                max_lines=5,
-                show_label=True,
-                container=True
+                placeholder="e.g., 'Deploy Excavator 1 to Soil Area 1'...",
+                lines=2
             )
 
             with gr.Row():
-                send_btn = gr.Button("Generate Tasks", variant="primary", size="sm")
-                clear_btn = gr.Button("Clear", variant="secondary", size="sm")
+                send_btn = gr.Button("🚀 Generate", variant="primary")
+                clear_btn = gr.Button("🗑️ Clear", variant="secondary")
 
         with gr.Column(scale=1):
-            gr.Markdown("### ⚙️ Generation Settings")
+            gr.Markdown("### ⚙️ Settings")
 
-            max_tokens = gr.Slider(
-                minimum=50,
-                maximum=500,
-                value=200,
-                step=10,
-                label="Max Tokens",
-                info="Maximum number of tokens to generate"
-            )
-
-            temperature = gr.Slider(
-                minimum=0.1,
-                maximum=2.0,
-                value=0.7,
-                step=0.1,
-                label="Temperature",
-                info="Controls randomness (lower = more focused)"
-            )
-
-            top_p = gr.Slider(
-                minimum=0.1,
-                maximum=1.0,
-                value=0.9,
-                step=0.05,
-                label="Top-p",
-                info="Nucleus sampling threshold"
-            )
+            max_tokens = gr.Slider(50, 500, 200, label="Max Tokens")
+            temperature = gr.Slider(0.1, 2.0, 0.7, step=0.1, label="Temperature")
+            top_p = gr.Slider(0.1, 1.0, 0.9, step=0.05, label="Top-p")
 
-    # Example dialogues
+    # Examples
     gr.Examples(
         examples=[
             ["Deploy Excavator 1 to Soil Area 1 for excavation."],
-            ["Send Dump Truck 1 to collect material, then unload at storage area."],
-            ["Move all robots to avoid Puddle 1 after inspection."],
-            ["Deploy multiple excavators to different soil areas simultaneously."],
-            ["Coordinate dump trucks to transport materials from excavation site to storage."],
-            ["Send robot to inspect rock area, then avoid with all other robots."],
-            ["Return all robots to start position after completing tasks."],
+            ["Send Dump Truck 1 to collect material and unload at storage."],
+            ["Move all robots to avoid dangerous Puddle 1."],
+            ["Coordinate multiple excavators across different areas."],
+            ["Create evacuation sequence for all robots."],
        ],
        inputs=msg,
-        label="💡 Example Robot Commands"
+        label="💡 Try these examples"
    )
 
    # Event handlers
-    msg.submit(
-        chat_interface,
-        inputs=[msg, chatbot, max_tokens, temperature, top_p],
-        outputs=[chatbot, msg]
-    )
-
-    send_btn.click(
-        chat_interface,
-        inputs=[msg, chatbot, max_tokens, temperature, top_p],
-        outputs=[chatbot, msg]
-    )
-
-    clear_btn.click(
-        lambda: ([], ""),
-        outputs=[chatbot, msg]
-    )
+    msg.submit(chat_interface, [msg, chatbot, max_tokens, temperature, top_p], [chatbot, msg])
+    send_btn.click(chat_interface, [msg, chatbot, max_tokens, temperature, top_p], [chatbot, msg])
+    clear_btn.click(lambda: ([], ""), outputs=[chatbot, msg])
 
 if __name__ == "__main__":
-    demo.launch()
+    demo.launch(server_name="0.0.0.0", server_port=7860)
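
Both sides of the diff reference `bnb_config`, but its definition falls outside the hunk context. A minimal sketch of what a QLoRA-style 4-bit config typically looks like — the `nf4` quant type and double quantization are assumptions; only the float16 compute dtype is corroborated by the `torch_dtype=torch.float16` line in the diff:

```python
import torch
from transformers import BitsAndBytesConfig

# Sketch only: the actual bnb_config in app.py is not visible in this commit.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                     # QLoRA keeps the base weights in 4-bit
    bnb_4bit_quant_type="nf4",             # assumption: NF4 is the usual QLoRA choice
    bnb_4bit_compute_dtype=torch.float16,  # matches torch_dtype=torch.float16 in the diff
    bnb_4bit_use_double_quant=True,        # assumption: nested quantization to cut memory
)
```

With 4-bit weights an 8B model occupies roughly 5 GB, which is what makes it fit on the `t4-medium` hardware (16 GB VRAM) that the README change requests.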
requirements.txt CHANGED
@@ -5,3 +5,5 @@ peft==0.7.1
 bitsandbytes==0.41.3
 accelerate==0.24.1
 scipy==1.11.4
+sentencepiece
+protobuf
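
The two new unpinned dependencies most plausibly back the tokenizer change above: with `use_fast=False`, transformers takes the slow-tokenizer path, which commonly pulls in sentencepiece and protobuf. A quick environment check (a sketch; the names below are import names, not pip names):

```python
import importlib.util

# Confirm the Space environment can satisfy the slow-tokenizer imports.
for module in ("sentencepiece", "google.protobuf", "transformers", "peft", "bitsandbytes"):
    found = importlib.util.find_spec(module) is not None
    print(f"{module}: {'ok' if found else 'MISSING'}")
```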