liuyiyang01
commited on
Commit
·
1724445
1
Parent(s):
ffe3930
dev streaming with oss
Browse files
app.py
CHANGED
@@ -11,6 +11,7 @@ import numpy as np
|
|
11 |
from datetime import datetime, timedelta
|
12 |
from collections import defaultdict
|
13 |
import shutil
|
|
|
14 |
|
15 |
# os.environ["SPACES_QUEUE_ENABLED"] = "true"
|
16 |
|
@@ -159,23 +160,57 @@ def format_logs_for_display(logs: list) -> str:
|
|
159 |
|
160 |
|
161 |
###############################################################################
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
162 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
163 |
|
164 |
-
def stream_simulation_results(result_folder: str, task_id: str, fps: int = 30):
|
165 |
"""
|
166 |
-
|
167 |
|
168 |
参数:
|
169 |
-
result_folder:
|
170 |
task_id: 后端任务ID用于状态查询
|
171 |
fps: 输出视频的帧率
|
|
|
172 |
|
173 |
生成:
|
174 |
生成的视频文件路径 (分段输出)
|
175 |
"""
|
176 |
# 初始化变量
|
177 |
-
|
178 |
-
os.makedirs(result_folder, exist_ok=True)
|
179 |
frame_buffer: List[np.ndarray] = []
|
180 |
frames_per_segment = fps * 2 # 每2秒60帧
|
181 |
processed_files = set()
|
@@ -184,6 +219,11 @@ def stream_simulation_results(result_folder: str, task_id: str, fps: int = 30):
|
|
184 |
status_check_interval = 5 # 每5秒检查一次后端状态
|
185 |
max_time = 240
|
186 |
|
|
|
|
|
|
|
|
|
|
|
187 |
while max_time > 0:
|
188 |
max_time -= 1
|
189 |
current_time = time.time()
|
@@ -194,9 +234,9 @@ def stream_simulation_results(result_folder: str, task_id: str, fps: int = 30):
|
|
194 |
print("status: ", status)
|
195 |
if status.get("status") == "completed":
|
196 |
# 确保处理完所有已生成的图片
|
197 |
-
|
198 |
if frame_buffer:
|
199 |
-
yield create_video_segment(frame_buffer, fps, width, height)
|
200 |
break
|
201 |
elif status.get("status") == "failed":
|
202 |
raise gr.Error(f"任务执行失败: {status.get('result', '未知错误')}")
|
@@ -204,35 +244,40 @@ def stream_simulation_results(result_folder: str, task_id: str, fps: int = 30):
|
|
204 |
break
|
205 |
last_status_check = current_time
|
206 |
|
207 |
-
#
|
208 |
-
|
209 |
-
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
|
218 |
-
|
219 |
-
|
220 |
-
|
221 |
-
|
222 |
-
|
223 |
-
|
224 |
-
|
225 |
-
|
226 |
-
|
227 |
-
|
228 |
-
|
229 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
230 |
|
231 |
-
|
232 |
-
|
233 |
-
segment_frames = frame_buffer[:frames_per_segment]
|
234 |
-
frame_buffer = frame_buffer[frames_per_segment:]
|
235 |
-
yield create_video_segment(segment_frames, fps, width, height)
|
236 |
|
237 |
time.sleep(1) # 避免过于频繁检查
|
238 |
|
@@ -255,25 +300,27 @@ def create_video_segment(frames: List[np.ndarray], fps: int, width: int, height:
|
|
255 |
|
256 |
return segment_name
|
257 |
|
258 |
-
def
|
259 |
-
"""
|
260 |
-
|
261 |
-
|
262 |
-
|
263 |
-
|
264 |
-
|
265 |
-
|
266 |
-
|
267 |
-
|
268 |
-
|
269 |
-
|
270 |
-
|
271 |
-
|
272 |
-
|
273 |
-
|
274 |
-
|
275 |
-
|
276 |
-
|
|
|
|
|
277 |
|
278 |
|
279 |
|
|
|
11 |
from datetime import datetime, timedelta
|
12 |
from collections import defaultdict
|
13 |
import shutil
|
14 |
+
from urllib.parse import urljoin
|
15 |
|
16 |
# os.environ["SPACES_QUEUE_ENABLED"] = "true"
|
17 |
|
|
|
160 |
|
161 |
|
162 |
###############################################################################
|
163 |
+
def list_public_oss_files(base_url: str) -> List[str]:
|
164 |
+
"""列出公共OSS文件夹中的所有图片文件"""
|
165 |
+
# 注意:这需要OSS支持目录列表功能,或者你有预先知道的文件命名规则
|
166 |
+
# 如果OSS不支持目录列表,可能需要后端API提供文件列表
|
167 |
+
# 这里假设可以直接通过HTTP访问
|
168 |
+
|
169 |
+
# 实际情况可能需要根据你的OSS具体配置调整
|
170 |
+
# 这里只是一个示例实现
|
171 |
+
try:
|
172 |
+
response = requests.get(base_url)
|
173 |
+
if response.status_code == 200:
|
174 |
+
# 这里需要根据OSS返回的实际内容解析文件列表
|
175 |
+
# 可能需要使用HTML解析器或正则表达式
|
176 |
+
# 以下只是示例
|
177 |
+
import re
|
178 |
+
files = re.findall(r'href="([^"]+\.(?:jpg|png|jpeg))"', response.text)
|
179 |
+
return sorted([urljoin(base_url, f) for f in files])
|
180 |
+
return []
|
181 |
+
except Exception as e:
|
182 |
+
print(f"Error listing public OSS files: {e}")
|
183 |
+
return []
|
184 |
|
185 |
+
def download_public_file(url: str, local_path: str):
|
186 |
+
"""下载公开可访问的文件"""
|
187 |
+
try:
|
188 |
+
response = requests.get(url, stream=True)
|
189 |
+
if response.status_code == 200:
|
190 |
+
with open(local_path, 'wb') as f:
|
191 |
+
for chunk in response.iter_content(1024):
|
192 |
+
f.write(chunk)
|
193 |
+
return True
|
194 |
+
return False
|
195 |
+
except Exception as e:
|
196 |
+
print(f"Error downloading public file {url}: {e}")
|
197 |
+
return False
|
198 |
|
199 |
+
def stream_simulation_results(result_folder: str, task_id: str, fps: int = 30, request: gr.Request):
|
200 |
"""
|
201 |
+
流式输出仿真结果,从公共OSS读取图片
|
202 |
|
203 |
参数:
|
204 |
+
result_folder: OSS上包含生成图片的文件夹URL (从后端API返回)
|
205 |
task_id: 后端任务ID用于状态查询
|
206 |
fps: 输出视频的帧率
|
207 |
+
request: Gradio请求对象
|
208 |
|
209 |
生成:
|
210 |
生成的视频文件路径 (分段输出)
|
211 |
"""
|
212 |
# 初始化变量
|
213 |
+
image_folder = urljoin(result_folder, "image/") # 确保以/结尾
|
|
|
214 |
frame_buffer: List[np.ndarray] = []
|
215 |
frames_per_segment = fps * 2 # 每2秒60帧
|
216 |
processed_files = set()
|
|
|
219 |
status_check_interval = 5 # 每5秒检查一次后端状态
|
220 |
max_time = 240
|
221 |
|
222 |
+
# 创建临时目录存储下载的图片
|
223 |
+
user_dir = os.path.join(TMP_ROOT, str(request.session_hash))
|
224 |
+
local_image_dir = os.path.join(user_dir, "tasks", "images")
|
225 |
+
os.makedirs(local_image_dir, exist_ok=True)
|
226 |
+
|
227 |
while max_time > 0:
|
228 |
max_time -= 1
|
229 |
current_time = time.time()
|
|
|
234 |
print("status: ", status)
|
235 |
if status.get("status") == "completed":
|
236 |
# 确保处理完所有已生成的图片
|
237 |
+
process_remaining_public_oss_images(image_folder, local_image_dir, processed_files, frame_buffer)
|
238 |
if frame_buffer:
|
239 |
+
yield create_video_segment(frame_buffer, fps, width, height, request)
|
240 |
break
|
241 |
elif status.get("status") == "failed":
|
242 |
raise gr.Error(f"任务执行失败: {status.get('result', '未知错误')}")
|
|
|
244 |
break
|
245 |
last_status_check = current_time
|
246 |
|
247 |
+
# 从公共OSS获取文件列表
|
248 |
+
try:
|
249 |
+
# 注意:这里假设可以直接列出OSS文件
|
250 |
+
# 如果不行,可能需要后端API提供文件列表
|
251 |
+
oss_files = list_public_oss_files(image_folder)
|
252 |
+
new_files = [f for f in oss_files if f not in processed_files]
|
253 |
+
has_new_frames = False
|
254 |
+
|
255 |
+
for file_url in new_files:
|
256 |
+
try:
|
257 |
+
# 下载文件到本地
|
258 |
+
filename = os.path.basename(file_url)
|
259 |
+
local_path = os.path.join(local_image_dir, filename)
|
260 |
+
if download_public_file(file_url, local_path):
|
261 |
+
# 读取图片
|
262 |
+
frame = cv2.imread(local_path)
|
263 |
+
if frame is not None:
|
264 |
+
if width == 0: # 第一次获取图像尺寸
|
265 |
+
height, width = frame.shape[:2]
|
266 |
+
|
267 |
+
frame_buffer.append(frame)
|
268 |
+
processed_files.add(file_url)
|
269 |
+
has_new_frames = True
|
270 |
+
except Exception as e:
|
271 |
+
print(f"Error processing {file_url}: {e}")
|
272 |
+
|
273 |
+
# 如果有新帧且积累够60帧,输出视频片段
|
274 |
+
if has_new_frames and len(frame_buffer) >= frames_per_segment:
|
275 |
+
segment_frames = frame_buffer[:frames_per_segment]
|
276 |
+
frame_buffer = frame_buffer[frames_per_segment:]
|
277 |
+
yield create_video_segment(segment_frames, fps, width, height, request)
|
278 |
|
279 |
+
except Exception as e:
|
280 |
+
print(f"Error accessing public OSS: {e}")
|
|
|
|
|
|
|
281 |
|
282 |
time.sleep(1) # 避免过于频繁检查
|
283 |
|
|
|
300 |
|
301 |
return segment_name
|
302 |
|
303 |
+
def process_remaining_public_oss_images(oss_folder: str, local_dir: str, processed_files: set, frame_buffer: List[np.ndarray]):
|
304 |
+
"""处理公共OSS上剩余的图片"""
|
305 |
+
try:
|
306 |
+
oss_files = list_public_oss_files(oss_folder)
|
307 |
+
new_files = [f for f in oss_files if f not in processed_files]
|
308 |
+
|
309 |
+
for file_url in new_files:
|
310 |
+
try:
|
311 |
+
# 下载文件到本地
|
312 |
+
filename = os.path.basename(file_url)
|
313 |
+
local_path = os.path.join(local_dir, filename)
|
314 |
+
if download_public_file(file_url, local_path):
|
315 |
+
# 读取图片
|
316 |
+
frame = cv2.imread(local_path)
|
317 |
+
if frame is not None:
|
318 |
+
frame_buffer.append(frame)
|
319 |
+
processed_files.add(file_url)
|
320 |
+
except Exception as e:
|
321 |
+
print(f"Error processing remaining {file_url}: {e}")
|
322 |
+
except Exception as e:
|
323 |
+
print(f"Error accessing public OSS for remaining files: {e}")
|
324 |
|
325 |
|
326 |
|