liuyiyang01 commited on
Commit
1724445
·
1 Parent(s): ffe3930

dev streaming with oss

Browse files
Files changed (1) hide show
  1. app.py +101 -54
app.py CHANGED
@@ -11,6 +11,7 @@ import numpy as np
11
  from datetime import datetime, timedelta
12
  from collections import defaultdict
13
  import shutil
 
14
 
15
  # os.environ["SPACES_QUEUE_ENABLED"] = "true"
16
 
@@ -159,23 +160,57 @@ def format_logs_for_display(logs: list) -> str:
159
 
160
 
161
  ###############################################################################
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
162
 
 
 
 
 
 
 
 
 
 
 
 
 
 
163
 
164
- def stream_simulation_results(result_folder: str, task_id: str, fps: int = 30):
165
  """
166
- 流式输出仿真结果,同时监控图片文件夹和后端任务状态
167
 
168
  参数:
169
- result_folder: 包含生成图片的文件夹路径
170
  task_id: 后端任务ID用于状态查询
171
  fps: 输出视频的帧率
 
172
 
173
  生成:
174
  生成的视频文件路径 (分段输出)
175
  """
176
  # 初始化变量
177
- result_folder = os.path.join(result_folder, "image")
178
- os.makedirs(result_folder, exist_ok=True)
179
  frame_buffer: List[np.ndarray] = []
180
  frames_per_segment = fps * 2 # 每2秒60帧
181
  processed_files = set()
@@ -184,6 +219,11 @@ def stream_simulation_results(result_folder: str, task_id: str, fps: int = 30):
184
  status_check_interval = 5 # 每5秒检查一次后端状态
185
  max_time = 240
186
 
 
 
 
 
 
187
  while max_time > 0:
188
  max_time -= 1
189
  current_time = time.time()
@@ -194,9 +234,9 @@ def stream_simulation_results(result_folder: str, task_id: str, fps: int = 30):
194
  print("status: ", status)
195
  if status.get("status") == "completed":
196
  # 确保处理完所有已生成的图片
197
- process_remaining_images(result_folder, processed_files, frame_buffer)
198
  if frame_buffer:
199
- yield create_video_segment(frame_buffer, fps, width, height)
200
  break
201
  elif status.get("status") == "failed":
202
  raise gr.Error(f"任务执行失败: {status.get('result', '未知错误')}")
@@ -204,35 +244,40 @@ def stream_simulation_results(result_folder: str, task_id: str, fps: int = 30):
204
  break
205
  last_status_check = current_time
206
 
207
- # 处理新生成的图片
208
- current_files = sorted(
209
- [f for f in os.listdir(result_folder)
210
- if f.lower().endswith(('.png', '.jpg', '.jpeg'))],
211
- key=lambda x: os.path.splitext(x)[0] # 按文件名排序
212
- )
213
-
214
- new_files = [f for f in current_files if f not in processed_files]
215
- has_new_frames = False
216
-
217
- for filename in new_files:
218
- try:
219
- img_path = os.path.join(result_folder, filename)
220
- frame = cv2.imread(img_path)
221
- if frame is not None:
222
- if width == 0: # 第一次获取图像尺寸
223
- height, width = frame.shape[:2]
224
-
225
- frame_buffer.append(frame)
226
- processed_files.add(filename)
227
- has_new_frames = True
228
- except Exception as e:
229
- print(f"Error processing {filename}: {e}")
 
 
 
 
 
 
 
 
230
 
231
- # 如果有新帧且积累够60帧,输出视频片段
232
- if has_new_frames and len(frame_buffer) >= frames_per_segment:
233
- segment_frames = frame_buffer[:frames_per_segment]
234
- frame_buffer = frame_buffer[frames_per_segment:]
235
- yield create_video_segment(segment_frames, fps, width, height)
236
 
237
  time.sleep(1) # 避免过于频繁检查
238
 
@@ -255,25 +300,27 @@ def create_video_segment(frames: List[np.ndarray], fps: int, width: int, height:
255
 
256
  return segment_name
257
 
258
- def process_remaining_images(result_folder: str, processed_files: set, frame_buffer: List[np.ndarray]):
259
- """处理剩余的图片"""
260
- current_files = sorted(
261
- [f for f in os.listdir(result_folder)
262
- if f.lower().endswith(('.png', '.jpg', '.jpeg'))],
263
- key=lambda x: os.path.splitext(x)[0]
264
- )
265
-
266
- new_files = [f for f in current_files if f not in processed_files]
267
-
268
- for filename in new_files:
269
- try:
270
- img_path = os.path.join(result_folder, filename)
271
- frame = cv2.imread(img_path)
272
- if frame is not None:
273
- frame_buffer.append(frame)
274
- processed_files.add(filename)
275
- except Exception as e:
276
- print(f"Error processing remaining {filename}: {e}")
 
 
277
 
278
 
279
 
 
11
  from datetime import datetime, timedelta
12
  from collections import defaultdict
13
  import shutil
14
+ from urllib.parse import urljoin
15
 
16
  # os.environ["SPACES_QUEUE_ENABLED"] = "true"
17
 
 
160
 
161
 
162
  ###############################################################################
163
+ def list_public_oss_files(base_url: str) -> List[str]:
164
+ """列出公共OSS文件夹中的所有图片文件"""
165
+ # 注意:这需要OSS支持目录列表功能,或者你有预先知道的文件命名规则
166
+ # 如果OSS不支持目录列表,可能需要后端API提供文件列表
167
+ # 这里假设可以直接通过HTTP访问
168
+
169
+ # 实际情况可能需要根据你的OSS具体配置调整
170
+ # 这里只是一个示例实现
171
+ try:
172
+ response = requests.get(base_url)
173
+ if response.status_code == 200:
174
+ # 这里需要根据OSS返回的实际内容解析文件列表
175
+ # 可能需要使用HTML解析器或正则表达式
176
+ # 以下只是示例
177
+ import re
178
+ files = re.findall(r'href="([^"]+\.(?:jpg|png|jpeg))"', response.text)
179
+ return sorted([urljoin(base_url, f) for f in files])
180
+ return []
181
+ except Exception as e:
182
+ print(f"Error listing public OSS files: {e}")
183
+ return []
184
 
185
+ def download_public_file(url: str, local_path: str):
186
+ """下载公开可访问的文件"""
187
+ try:
188
+ response = requests.get(url, stream=True)
189
+ if response.status_code == 200:
190
+ with open(local_path, 'wb') as f:
191
+ for chunk in response.iter_content(1024):
192
+ f.write(chunk)
193
+ return True
194
+ return False
195
+ except Exception as e:
196
+ print(f"Error downloading public file {url}: {e}")
197
+ return False
198
 
199
+ def stream_simulation_results(result_folder: str, task_id: str, fps: int = 30, request: gr.Request):
200
  """
201
+ 流式输出仿真结果,从公共OSS读取图片
202
 
203
  参数:
204
+ result_folder: OSS上包含生成图片的文件夹URL (从后端API返回)
205
  task_id: 后端任务ID用于状态查询
206
  fps: 输出视频的帧率
207
+ request: Gradio请求对象
208
 
209
  生成:
210
  生成的视频文件路径 (分段输出)
211
  """
212
  # 初始化变量
213
+ image_folder = urljoin(result_folder, "image/") # 确保以/结尾
 
214
  frame_buffer: List[np.ndarray] = []
215
  frames_per_segment = fps * 2 # 每2秒60帧
216
  processed_files = set()
 
219
  status_check_interval = 5 # 每5秒检查一次后端状态
220
  max_time = 240
221
 
222
+ # 创建临时目录存储下载的图片
223
+ user_dir = os.path.join(TMP_ROOT, str(request.session_hash))
224
+ local_image_dir = os.path.join(user_dir, "tasks", "images")
225
+ os.makedirs(local_image_dir, exist_ok=True)
226
+
227
  while max_time > 0:
228
  max_time -= 1
229
  current_time = time.time()
 
234
  print("status: ", status)
235
  if status.get("status") == "completed":
236
  # 确保处理完所有已生成的图片
237
+ process_remaining_public_oss_images(image_folder, local_image_dir, processed_files, frame_buffer)
238
  if frame_buffer:
239
+ yield create_video_segment(frame_buffer, fps, width, height, request)
240
  break
241
  elif status.get("status") == "failed":
242
  raise gr.Error(f"任务执行失败: {status.get('result', '未知错误')}")
 
244
  break
245
  last_status_check = current_time
246
 
247
+ # 从公共OSS获取文件列表
248
+ try:
249
+ # 注意:这里假设可以直接列出OSS文件
250
+ # 如果不行,可能需要后端API提供文件列表
251
+ oss_files = list_public_oss_files(image_folder)
252
+ new_files = [f for f in oss_files if f not in processed_files]
253
+ has_new_frames = False
254
+
255
+ for file_url in new_files:
256
+ try:
257
+ # 下载文件到本地
258
+ filename = os.path.basename(file_url)
259
+ local_path = os.path.join(local_image_dir, filename)
260
+ if download_public_file(file_url, local_path):
261
+ # 读取图片
262
+ frame = cv2.imread(local_path)
263
+ if frame is not None:
264
+ if width == 0: # 第一次获取图像尺寸
265
+ height, width = frame.shape[:2]
266
+
267
+ frame_buffer.append(frame)
268
+ processed_files.add(file_url)
269
+ has_new_frames = True
270
+ except Exception as e:
271
+ print(f"Error processing {file_url}: {e}")
272
+
273
+ # 如果有新帧且积累够60帧,输出视频片段
274
+ if has_new_frames and len(frame_buffer) >= frames_per_segment:
275
+ segment_frames = frame_buffer[:frames_per_segment]
276
+ frame_buffer = frame_buffer[frames_per_segment:]
277
+ yield create_video_segment(segment_frames, fps, width, height, request)
278
 
279
+ except Exception as e:
280
+ print(f"Error accessing public OSS: {e}")
 
 
 
281
 
282
  time.sleep(1) # 避免过于频繁检查
283
 
 
300
 
301
  return segment_name
302
 
303
+ def process_remaining_public_oss_images(oss_folder: str, local_dir: str, processed_files: set, frame_buffer: List[np.ndarray]):
304
+ """处理公共OSS上剩余的图片"""
305
+ try:
306
+ oss_files = list_public_oss_files(oss_folder)
307
+ new_files = [f for f in oss_files if f not in processed_files]
308
+
309
+ for file_url in new_files:
310
+ try:
311
+ # 下载文件到本地
312
+ filename = os.path.basename(file_url)
313
+ local_path = os.path.join(local_dir, filename)
314
+ if download_public_file(file_url, local_path):
315
+ # 读取图片
316
+ frame = cv2.imread(local_path)
317
+ if frame is not None:
318
+ frame_buffer.append(frame)
319
+ processed_files.add(file_url)
320
+ except Exception as e:
321
+ print(f"Error processing remaining {file_url}: {e}")
322
+ except Exception as e:
323
+ print(f"Error accessing public OSS for remaining files: {e}")
324
 
325
 
326