liuyiyang01 commited on
Commit
6964b3e
·
1 Parent(s): 4f482ac

oss support

Browse files
Files changed (2) hide show
  1. app.py +1 -1
  2. app_utils.py +30 -14
app.py CHANGED
@@ -24,7 +24,7 @@ header_html = """
24
  </div>
25
  <div style="display: flex; gap: 15px; align-items: center;">
26
  <a href="https://github.com/InternRobotics" target="_blank" style="text-decoration: none; transition: transform 0.2s;" onmouseover="this.style.transform='scale(1.1)'" onmouseout="this.style.transform='scale(1)'">
27
- <img src="assets/github-mark.png" alt="GitHub" style="height: 30px;">
28
  </a>
29
  <a href="https://huggingface.co/InternRobotics" target="_blank" style="text-decoration: none; transition: transform 0.2s;" onmouseover="this.style.transform='scale(1.1)'" onmouseout="this.style.transform='scale(1)'">
30
  <img src="https://huggingface.co/front/assets/huggingface_logo-noborder.svg" alt="HuggingFace" style="height: 30px;">
 
24
  </div>
25
  <div style="display: flex; gap: 15px; align-items: center;">
26
  <a href="https://github.com/InternRobotics" target="_blank" style="text-decoration: none; transition: transform 0.2s;" onmouseover="this.style.transform='scale(1.1)'" onmouseout="this.style.transform='scale(1)'">
27
+ <img src="https://github.githubassets.com/images/modules/logos_page/GitHub-Mark.png" alt="GitHub" style="height: 30px;">
28
  </a>
29
  <a href="https://huggingface.co/InternRobotics" target="_blank" style="text-decoration: none; transition: transform 0.2s;" onmouseover="this.style.transform='scale(1.1)'" onmouseout="this.style.transform='scale(1)'">
30
  <img src="https://huggingface.co/front/assets/huggingface_logo-noborder.svg" alt="HuggingFace" style="height: 30px;">
app_utils.py CHANGED
@@ -19,7 +19,7 @@ os.makedirs(TMP_ROOT, exist_ok=True)
19
 
20
 
21
  # 后端API配置(可配置化)
22
- BACKEND_URL = os.getenv("BACKEND_URL")
23
  API_ENDPOINTS = {
24
  "submit_task": f"{BACKEND_URL}/predict/video",
25
  "query_status": f"{BACKEND_URL}/predict/task",
@@ -183,6 +183,13 @@ def download_oss_file(oss_path: str, local_path: str):
183
  """从OSS下载文件到本地"""
184
  bucket.get_object_to_file(oss_path, local_path)
185
 
 
 
 
 
 
 
 
186
  def stream_simulation_results(result_folder: str, task_id: str, request: gr.Request, fps: int = 30):
187
  """
188
  流式输出仿真结果,从OSS读取图片
@@ -190,8 +197,9 @@ def stream_simulation_results(result_folder: str, task_id: str, request: gr.Requ
190
  参数:
191
  result_folder: OSS上包含生成图片的文件夹路径
192
  task_id: 后端任务ID用于状态查询
193
- fps: 输出视频的帧率
194
  request: Gradio请求对象
 
 
195
 
196
  生成:
197
  生成的视频文件路径 (分段输出)
@@ -204,12 +212,13 @@ def stream_simulation_results(result_folder: str, task_id: str, request: gr.Requ
204
  processed_files = set()
205
  width, height = 0, 0
206
  last_status_check = 0
207
- status_check_interval = 5 # 每5秒检查一次后端状态
208
  max_time = 240
209
 
210
  # 创建临时目录存储下载的图片
211
  user_dir = os.path.join(TMP_ROOT, str(request.session_hash))
212
- local_image_dir = os.path.join(user_dir, "tasks", "images")
 
213
  os.makedirs(local_image_dir, exist_ok=True)
214
 
215
  while max_time > 0:
@@ -219,7 +228,7 @@ def stream_simulation_results(result_folder: str, task_id: str, request: gr.Requ
219
  # 定期检查后端状态
220
  if current_time - last_status_check > status_check_interval:
221
  status = get_task_status(task_id)
222
- print("status: ", status)
223
  if status.get("status") == "completed":
224
  # 确保处理完所有已生成的图片
225
  process_remaining_oss_images(image_folder, local_image_dir, processed_files, frame_buffer)
@@ -275,7 +284,7 @@ def create_video_segment(frames: List[np.ndarray], fps: int, width: int, height:
275
  """创建视频片段"""
276
  user_dir = os.path.join(TMP_ROOT, str(req.session_hash))
277
  os.makedirs(user_dir, exist_ok=True)
278
- video_chunk_path = os.path.join(user_dir, "tasks/video_chunk")
279
  os.makedirs(video_chunk_path, exist_ok=True)
280
  segment_name = os.path.join(video_chunk_path, f"output_{uuid.uuid4()}.mp4")
281
  fourcc = cv2.VideoWriter_fourcc(*'mp4v')
@@ -364,7 +373,7 @@ def get_task_status(task_id: str) -> dict:
364
  )
365
  return response.json()
366
  except Exception as e:
367
- return {"status": "error", "message": str(e)}
368
 
369
  def terminate_task(task_id: str) -> Optional[dict]:
370
  """
@@ -431,6 +440,7 @@ def run_simulation(
431
  # 记录用户提交
432
  user_ip = request.client.host if request else "unknown"
433
  session_id = request.session_hash
 
434
 
435
  if not is_request_allowed(user_ip):
436
  log_submission(scene, prompt, model, max_step, user_ip, "IP blocked temporarily")
@@ -454,19 +464,20 @@ def run_simulation(
454
  status = get_task_status(task_id)
455
  print("first status: ", status)
456
  result_folder = status.get("result", "")
 
457
  except Exception as e:
458
  log_submission(scene, prompt, model, max_step, user_ip, str(e))
459
  raise gr.Error(f"error occurred when parsing submission result from backend: {str(e)}")
460
 
461
 
462
- if not os.path.exists(result_folder):
463
- log_submission(scene, prompt, model, max_step, user_ip, "Result folder provided by backend doesn't exist")
464
- raise gr.Error(f"Result folder provided by backend doesn't exist: <PATH>{result_folder}")
465
 
466
 
467
  # 流式输出视频片段
468
  try:
469
- for video_path in stream_simulation_results(result_folder, task_id):
470
  if video_path:
471
  yield video_path, history
472
  except Exception as e:
@@ -477,9 +488,14 @@ def run_simulation(
477
  status = get_task_status(task_id)
478
  print("status: ", status)
479
  if status.get("status") == "completed":
480
- video_path = os.path.join(status.get("result"), "manipulation.mp4")
481
- print("video_path: ", video_path)
482
- video_path = convert_to_h264(video_path)
 
 
 
 
 
483
 
484
  # 创建新的历史记录条目
485
  new_entry = {
 
19
 
20
 
21
  # 后端API配置(可配置化)
22
+ BACKEND_URL = os.getenv("BACKEND_URL", "http://localhost:8000")
23
  API_ENDPOINTS = {
24
  "submit_task": f"{BACKEND_URL}/predict/video",
25
  "query_status": f"{BACKEND_URL}/predict/task",
 
183
  """从OSS下载文件到本地"""
184
  bucket.get_object_to_file(oss_path, local_path)
185
 
186
+ def oss_file_exists(oss_path):
187
+ try:
188
+ # Assuming you have an OSS bucket object
189
+ return bucket.object_exists(oss_path)
190
+ except Exception as e:
191
+ print(f"Error checking if file exists in OSS: {str(e)}")
192
+ return False
193
  def stream_simulation_results(result_folder: str, task_id: str, request: gr.Request, fps: int = 30):
194
  """
195
  流式输出仿真结果,从OSS读取图片
 
197
  参数:
198
  result_folder: OSS上包含生成图片的文件夹路径
199
  task_id: 后端任务ID用于状态查询
 
200
  request: Gradio请求对象
201
+ fps: 输出视频的帧率
202
+
203
 
204
  生成:
205
  生成的视频文件路径 (分段输出)
 
212
  processed_files = set()
213
  width, height = 0, 0
214
  last_status_check = 0
215
+ status_check_interval = 1 # 每5秒检查一次后端状态
216
  max_time = 240
217
 
218
  # 创建临时目录存储下载的图片
219
  user_dir = os.path.join(TMP_ROOT, str(request.session_hash))
220
+ local_image_dir = os.path.join(user_dir, task_id, "tasks", "images")
221
+
222
  os.makedirs(local_image_dir, exist_ok=True)
223
 
224
  while max_time > 0:
 
228
  # 定期检查后端状态
229
  if current_time - last_status_check > status_check_interval:
230
  status = get_task_status(task_id)
231
+ print(str(request.session_hash), "status: ", status)
232
  if status.get("status") == "completed":
233
  # 确保处理完所有已生成的图片
234
  process_remaining_oss_images(image_folder, local_image_dir, processed_files, frame_buffer)
 
284
  """创建视频片段"""
285
  user_dir = os.path.join(TMP_ROOT, str(req.session_hash))
286
  os.makedirs(user_dir, exist_ok=True)
287
+ video_chunk_path = os.path.join(user_dir, "video_chunk")
288
  os.makedirs(video_chunk_path, exist_ok=True)
289
  segment_name = os.path.join(video_chunk_path, f"output_{uuid.uuid4()}.mp4")
290
  fourcc = cv2.VideoWriter_fourcc(*'mp4v')
 
373
  )
374
  return response.json()
375
  except Exception as e:
376
+ return {"status": "error get_task_status", "message": str(e)}
377
 
378
  def terminate_task(task_id: str) -> Optional[dict]:
379
  """
 
440
  # 记录用户提交
441
  user_ip = request.client.host if request else "unknown"
442
  session_id = request.session_hash
443
+ user_dir = os.path.join(TMP_ROOT, str(request.session_hash))
444
 
445
  if not is_request_allowed(user_ip):
446
  log_submission(scene, prompt, model, max_step, user_ip, "IP blocked temporarily")
 
464
  status = get_task_status(task_id)
465
  print("first status: ", status)
466
  result_folder = status.get("result", "")
467
+ result_folder = "gradio_demo/tasks/" + task_id
468
  except Exception as e:
469
  log_submission(scene, prompt, model, max_step, user_ip, str(e))
470
  raise gr.Error(f"error occurred when parsing submission result from backend: {str(e)}")
471
 
472
 
473
+ # if not os.path.exists(result_folder):
474
+ # log_submission(scene, prompt, model, max_step, user_ip, "Result folder provided by backend doesn't exist")
475
+ # raise gr.Error(f"Result folder provided by backend doesn't exist: <PATH>{result_folder}")
476
 
477
 
478
  # 流式输出视频片段
479
  try:
480
+ for video_path in stream_simulation_results(result_folder, task_id, request):
481
  if video_path:
482
  yield video_path, history
483
  except Exception as e:
 
488
  status = get_task_status(task_id)
489
  print("status: ", status)
490
  if status.get("status") == "completed":
491
+ # time.sleep(3)
492
+ oss_video_path = os.path.join(result_folder, "manipulation.mp4")
493
+ local_video_path = os.path.join(user_dir, task_id, "tasks", "manipulation.mp4")
494
+ download_oss_file(oss_video_path, local_video_path)
495
+ print("oss_video_path: ", oss_video_path)
496
+ print("local_video_path: ", local_video_path)
497
+
498
+ video_path = convert_to_h264(local_video_path)
499
 
500
  # 创建新的历史记录条目
501
  new_entry = {