liuyiyang01
commited on
Commit
·
6964b3e
1
Parent(s):
4f482ac
oss support
Browse files- app.py +1 -1
- app_utils.py +30 -14
app.py
CHANGED
@@ -24,7 +24,7 @@ header_html = """
|
|
24 |
</div>
|
25 |
<div style="display: flex; gap: 15px; align-items: center;">
|
26 |
<a href="https://github.com/InternRobotics" target="_blank" style="text-decoration: none; transition: transform 0.2s;" onmouseover="this.style.transform='scale(1.1)'" onmouseout="this.style.transform='scale(1)'">
|
27 |
-
<img src="
|
28 |
</a>
|
29 |
<a href="https://huggingface.co/InternRobotics" target="_blank" style="text-decoration: none; transition: transform 0.2s;" onmouseover="this.style.transform='scale(1.1)'" onmouseout="this.style.transform='scale(1)'">
|
30 |
<img src="https://huggingface.co/front/assets/huggingface_logo-noborder.svg" alt="HuggingFace" style="height: 30px;">
|
|
|
24 |
</div>
|
25 |
<div style="display: flex; gap: 15px; align-items: center;">
|
26 |
<a href="https://github.com/InternRobotics" target="_blank" style="text-decoration: none; transition: transform 0.2s;" onmouseover="this.style.transform='scale(1.1)'" onmouseout="this.style.transform='scale(1)'">
|
27 |
+
<img src="https://github.githubassets.com/images/modules/logos_page/GitHub-Mark.png" alt="GitHub" style="height: 30px;">
|
28 |
</a>
|
29 |
<a href="https://huggingface.co/InternRobotics" target="_blank" style="text-decoration: none; transition: transform 0.2s;" onmouseover="this.style.transform='scale(1.1)'" onmouseout="this.style.transform='scale(1)'">
|
30 |
<img src="https://huggingface.co/front/assets/huggingface_logo-noborder.svg" alt="HuggingFace" style="height: 30px;">
|
app_utils.py
CHANGED
@@ -19,7 +19,7 @@ os.makedirs(TMP_ROOT, exist_ok=True)
|
|
19 |
|
20 |
|
21 |
# 后端API配置(可配置化)
|
22 |
-
BACKEND_URL = os.getenv("BACKEND_URL")
|
23 |
API_ENDPOINTS = {
|
24 |
"submit_task": f"{BACKEND_URL}/predict/video",
|
25 |
"query_status": f"{BACKEND_URL}/predict/task",
|
@@ -183,6 +183,13 @@ def download_oss_file(oss_path: str, local_path: str):
|
|
183 |
"""从OSS下载文件到本地"""
|
184 |
bucket.get_object_to_file(oss_path, local_path)
|
185 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
186 |
def stream_simulation_results(result_folder: str, task_id: str, request: gr.Request, fps: int = 30):
|
187 |
"""
|
188 |
流式输出仿真结果,从OSS读取图片
|
@@ -190,8 +197,9 @@ def stream_simulation_results(result_folder: str, task_id: str, request: gr.Requ
|
|
190 |
参数:
|
191 |
result_folder: OSS上包含生成图片的文件夹路径
|
192 |
task_id: 后端任务ID用于状态查询
|
193 |
-
fps: 输出视频的帧率
|
194 |
request: Gradio请求对象
|
|
|
|
|
195 |
|
196 |
生成:
|
197 |
生成的视频文件路径 (分段输出)
|
@@ -204,12 +212,13 @@ def stream_simulation_results(result_folder: str, task_id: str, request: gr.Requ
|
|
204 |
processed_files = set()
|
205 |
width, height = 0, 0
|
206 |
last_status_check = 0
|
207 |
-
status_check_interval =
|
208 |
max_time = 240
|
209 |
|
210 |
# 创建临时目录存储下载的图片
|
211 |
user_dir = os.path.join(TMP_ROOT, str(request.session_hash))
|
212 |
-
local_image_dir = os.path.join(user_dir, "tasks", "images")
|
|
|
213 |
os.makedirs(local_image_dir, exist_ok=True)
|
214 |
|
215 |
while max_time > 0:
|
@@ -219,7 +228,7 @@ def stream_simulation_results(result_folder: str, task_id: str, request: gr.Requ
|
|
219 |
# 定期检查后端状态
|
220 |
if current_time - last_status_check > status_check_interval:
|
221 |
status = get_task_status(task_id)
|
222 |
-
print("status: ", status)
|
223 |
if status.get("status") == "completed":
|
224 |
# 确保处理完所有已生成的图片
|
225 |
process_remaining_oss_images(image_folder, local_image_dir, processed_files, frame_buffer)
|
@@ -275,7 +284,7 @@ def create_video_segment(frames: List[np.ndarray], fps: int, width: int, height:
|
|
275 |
"""创建视频片段"""
|
276 |
user_dir = os.path.join(TMP_ROOT, str(req.session_hash))
|
277 |
os.makedirs(user_dir, exist_ok=True)
|
278 |
-
video_chunk_path = os.path.join(user_dir, "
|
279 |
os.makedirs(video_chunk_path, exist_ok=True)
|
280 |
segment_name = os.path.join(video_chunk_path, f"output_{uuid.uuid4()}.mp4")
|
281 |
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
|
@@ -364,7 +373,7 @@ def get_task_status(task_id: str) -> dict:
|
|
364 |
)
|
365 |
return response.json()
|
366 |
except Exception as e:
|
367 |
-
return {"status": "error", "message": str(e)}
|
368 |
|
369 |
def terminate_task(task_id: str) -> Optional[dict]:
|
370 |
"""
|
@@ -431,6 +440,7 @@ def run_simulation(
|
|
431 |
# 记录用户提交
|
432 |
user_ip = request.client.host if request else "unknown"
|
433 |
session_id = request.session_hash
|
|
|
434 |
|
435 |
if not is_request_allowed(user_ip):
|
436 |
log_submission(scene, prompt, model, max_step, user_ip, "IP blocked temporarily")
|
@@ -454,19 +464,20 @@ def run_simulation(
|
|
454 |
status = get_task_status(task_id)
|
455 |
print("first status: ", status)
|
456 |
result_folder = status.get("result", "")
|
|
|
457 |
except Exception as e:
|
458 |
log_submission(scene, prompt, model, max_step, user_ip, str(e))
|
459 |
raise gr.Error(f"error occurred when parsing submission result from backend: {str(e)}")
|
460 |
|
461 |
|
462 |
-
if not os.path.exists(result_folder):
|
463 |
-
|
464 |
-
|
465 |
|
466 |
|
467 |
# 流式输出视频片段
|
468 |
try:
|
469 |
-
for video_path in stream_simulation_results(result_folder, task_id):
|
470 |
if video_path:
|
471 |
yield video_path, history
|
472 |
except Exception as e:
|
@@ -477,9 +488,14 @@ def run_simulation(
|
|
477 |
status = get_task_status(task_id)
|
478 |
print("status: ", status)
|
479 |
if status.get("status") == "completed":
|
480 |
-
|
481 |
-
|
482 |
-
|
|
|
|
|
|
|
|
|
|
|
483 |
|
484 |
# 创建新的历史记录条目
|
485 |
new_entry = {
|
|
|
19 |
|
20 |
|
21 |
# 后端API配置(可配置化)
|
22 |
+
BACKEND_URL = os.getenv("BACKEND_URL", "http://localhost:8000")
|
23 |
API_ENDPOINTS = {
|
24 |
"submit_task": f"{BACKEND_URL}/predict/video",
|
25 |
"query_status": f"{BACKEND_URL}/predict/task",
|
|
|
183 |
"""从OSS下载文件到本地"""
|
184 |
bucket.get_object_to_file(oss_path, local_path)
|
185 |
|
186 |
+
def oss_file_exists(oss_path):
|
187 |
+
try:
|
188 |
+
# Assuming you have an OSS bucket object
|
189 |
+
return bucket.object_exists(oss_path)
|
190 |
+
except Exception as e:
|
191 |
+
print(f"Error checking if file exists in OSS: {str(e)}")
|
192 |
+
return False
|
193 |
def stream_simulation_results(result_folder: str, task_id: str, request: gr.Request, fps: int = 30):
|
194 |
"""
|
195 |
流式输出仿真结果,从OSS读取图片
|
|
|
197 |
参数:
|
198 |
result_folder: OSS上包含生成图片的文件夹路径
|
199 |
task_id: 后端任务ID用于状态查询
|
|
|
200 |
request: Gradio请求对象
|
201 |
+
fps: 输出视频的帧率
|
202 |
+
|
203 |
|
204 |
生成:
|
205 |
生成的视频文件路径 (分段输出)
|
|
|
212 |
processed_files = set()
|
213 |
width, height = 0, 0
|
214 |
last_status_check = 0
|
215 |
+
status_check_interval = 1 # 每5秒检查一次后端状态
|
216 |
max_time = 240
|
217 |
|
218 |
# 创建临时目录存储下载的图片
|
219 |
user_dir = os.path.join(TMP_ROOT, str(request.session_hash))
|
220 |
+
local_image_dir = os.path.join(user_dir, task_id, "tasks", "images")
|
221 |
+
|
222 |
os.makedirs(local_image_dir, exist_ok=True)
|
223 |
|
224 |
while max_time > 0:
|
|
|
228 |
# 定期检查后端状态
|
229 |
if current_time - last_status_check > status_check_interval:
|
230 |
status = get_task_status(task_id)
|
231 |
+
print(str(request.session_hash), "status: ", status)
|
232 |
if status.get("status") == "completed":
|
233 |
# 确保处理完所有已生成的图片
|
234 |
process_remaining_oss_images(image_folder, local_image_dir, processed_files, frame_buffer)
|
|
|
284 |
"""创建视频片段"""
|
285 |
user_dir = os.path.join(TMP_ROOT, str(req.session_hash))
|
286 |
os.makedirs(user_dir, exist_ok=True)
|
287 |
+
video_chunk_path = os.path.join(user_dir, "video_chunk")
|
288 |
os.makedirs(video_chunk_path, exist_ok=True)
|
289 |
segment_name = os.path.join(video_chunk_path, f"output_{uuid.uuid4()}.mp4")
|
290 |
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
|
|
|
373 |
)
|
374 |
return response.json()
|
375 |
except Exception as e:
|
376 |
+
return {"status": "error get_task_status", "message": str(e)}
|
377 |
|
378 |
def terminate_task(task_id: str) -> Optional[dict]:
|
379 |
"""
|
|
|
440 |
# 记录用户提交
|
441 |
user_ip = request.client.host if request else "unknown"
|
442 |
session_id = request.session_hash
|
443 |
+
user_dir = os.path.join(TMP_ROOT, str(request.session_hash))
|
444 |
|
445 |
if not is_request_allowed(user_ip):
|
446 |
log_submission(scene, prompt, model, max_step, user_ip, "IP blocked temporarily")
|
|
|
464 |
status = get_task_status(task_id)
|
465 |
print("first status: ", status)
|
466 |
result_folder = status.get("result", "")
|
467 |
+
result_folder = "gradio_demo/tasks/" + task_id
|
468 |
except Exception as e:
|
469 |
log_submission(scene, prompt, model, max_step, user_ip, str(e))
|
470 |
raise gr.Error(f"error occurred when parsing submission result from backend: {str(e)}")
|
471 |
|
472 |
|
473 |
+
# if not os.path.exists(result_folder):
|
474 |
+
# log_submission(scene, prompt, model, max_step, user_ip, "Result folder provided by backend doesn't exist")
|
475 |
+
# raise gr.Error(f"Result folder provided by backend doesn't exist: <PATH>{result_folder}")
|
476 |
|
477 |
|
478 |
# 流式输出视频片段
|
479 |
try:
|
480 |
+
for video_path in stream_simulation_results(result_folder, task_id, request):
|
481 |
if video_path:
|
482 |
yield video_path, history
|
483 |
except Exception as e:
|
|
|
488 |
status = get_task_status(task_id)
|
489 |
print("status: ", status)
|
490 |
if status.get("status") == "completed":
|
491 |
+
# time.sleep(3)
|
492 |
+
oss_video_path = os.path.join(result_folder, "manipulation.mp4")
|
493 |
+
local_video_path = os.path.join(user_dir, task_id, "tasks", "manipulation.mp4")
|
494 |
+
download_oss_file(oss_video_path, local_video_path)
|
495 |
+
print("oss_video_path: ", oss_video_path)
|
496 |
+
print("local_video_path: ", local_video_path)
|
497 |
+
|
498 |
+
video_path = convert_to_h264(local_video_path)
|
499 |
|
500 |
# 创建新的历史记录条目
|
501 |
new_entry = {
|