import asyncio
import contextlib
import functools
import logging
import os
import shutil
import threading
from typing import Dict, Optional, Tuple

import deeplabcut as dlc
import requests
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse

# Instantiate the FastAPI application.
app = FastAPI()

# Accept CORS requests from any origin.
# NOTE(review): a wildcard origin combined with allow_credentials=True is very
# permissive; consider restricting allow_origins in production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # any origin
    allow_credentials=True,
    allow_methods=["*"],  # any HTTP method
    allow_headers=["*"],  # any request header
)

# DeepLabCut project directory and its config file.
project_path = "/app/kunin-dlc-240814"
config_path = os.path.join(project_path, "config.yaml")

# Directory that receives the downloaded input video and every DeepLabCut output.
WORKING_DIRECTORY = "/app/working"

# Log record format.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Serializes analysis jobs: each job clears and reuses WORKING_DIRECTORY, so only
# one may run at a time. It is only ever acquired inside a worker thread (see
# _analyze_locked), never on the event-loop thread, so waiting for it cannot
# stall the loop.
lock = threading.Lock()

# Output file basename -> True once it has been requested via /post_download.
downloaded_files: Dict[str, bool] = {}


def _is_safe_filename(name: str) -> bool:
    """Return True if *name* is a single plain path component.

    Rejects the empty string, "." and "..", and anything containing a
    directory separator, so the name cannot escape the directory it is
    joined onto.
    """
    if name in ("", ".", ".."):
        return False
    if os.path.basename(name) != name:
        return False
    return os.altsep is None or os.altsep not in name


@app.post("/analyze")
async def analyze_video(request: Request):
    """Download a video, run the DeepLabCut pipeline on it, and return download links.

    The request body must be a JSON object containing ``videoUrl`` and
    ``videoOssId``. The response contains ``videoOssId`` plus the
    ``/post_download/...`` paths of the labeled video and the filtered H5 file.

    Raises:
        HTTPException: 400 for an unparsable body, missing fields, or a
            ``videoOssId`` that is not a plain file name; 500 when the
            download or the analysis fails.
    """
    logging.info("Received request for video analysis")

    # Parse and validate BEFORE touching the working directory, so a malformed
    # request cannot wipe results that earlier callers have not fetched yet.
    try:
        data: Dict = await request.json()
        logging.info(f"Request data: {data}")
    except Exception as e:
        logging.error(f"Failed to parse JSON: {str(e)}")
        raise HTTPException(status_code=400, detail=f"Invalid JSON: {str(e)}") from e

    # A non-object body (list, string, number) cannot carry the required keys;
    # for a string, `in` would otherwise do a substring test.
    if not isinstance(data, dict) or 'videoUrl' not in data or 'videoOssId' not in data:
        logging.error("Missing videoUrl or videoOssId in request")
        raise HTTPException(status_code=400, detail="videoUrl and videoOssId are required")

    video_url = data['videoUrl']
    video_oss_id = data['videoOssId']

    # videoOssId becomes a file name under WORKING_DIRECTORY; reject values that
    # could escape it (e.g. "../x").
    if not _is_safe_filename(str(video_oss_id)):
        logging.error(f"Invalid videoOssId: {video_oss_id!r}")
        raise HTTPException(status_code=400, detail="videoOssId must be a plain file name")

    # The pipeline is blocking and long-running. Running it (and acquiring the
    # threading lock) in a worker thread keeps the event loop free to serve
    # other requests such as /post_download. A threading.Lock must never be
    # held across an `await` on the loop thread.
    loop = asyncio.get_running_loop()
    response_data = await loop.run_in_executor(
        None, functools.partial(_analyze_locked, video_url, video_oss_id)
    )
    logging.info("Returning response data")
    return response_data


def _analyze_locked(video_url, video_oss_id) -> Dict[str, str]:
    """Run one complete analysis job while holding ``lock`` (worker-thread side).

    Clears the working directory, downloads the video, runs DeepLabCut,
    renames the outputs, and registers them in ``downloaded_files``.

    Returns:
        The JSON-serializable response body for /analyze.

    Raises:
        HTTPException: 500 when the download or the analysis fails, or when
            the expected output files are not found.
    """
    with lock:
        clear_working_directory(WORKING_DIRECTORY)

        try:
            video_path = download_video(video_url, video_oss_id)
            logging.info(f"Downloaded video to: {video_path}")
        except HTTPException:
            # download_video already logged and built a client-facing error;
            # pass it through instead of wrapping it a second time.
            raise
        except Exception as e:
            logging.error(f"Error downloading video: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Error downloading video: {str(e)}") from e

        try:
            _run_dlc_pipeline(video_path)
            labeled_video_path, h5_file_path = find_and_rename_output_files(video_path)
        except Exception as e:
            logging.error(f"Error during video analysis: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Error during video analysis: {str(e)}") from e

        # Checked outside the try above so this 500 is returned as-is rather
        # than being re-wrapped by the generic analysis-error handler.
        if not labeled_video_path or not h5_file_path:
            logging.error("Output files missing after analysis")
            raise HTTPException(status_code=500, detail="Analysis completed, but output files are missing.")

        labeled_name = os.path.basename(labeled_video_path)
        h5_name = os.path.basename(h5_file_path)
        # Register the new outputs as not yet downloaded.
        downloaded_files[labeled_name] = False
        downloaded_files[h5_name] = False

    return {
        "videoOssId": video_oss_id,
        "labeled_video": f"/post_download/{labeled_name}",
        "h5_file": f"/post_download/{h5_name}",
    }


def _run_dlc_pipeline(video_path: str) -> None:
    """Run DeepLabCut inference, prediction filtering and labeled-video creation on *video_path*."""
    # Inference.
    logging.info(f"Starting analysis for video: {video_path}")
    dlc.analyze_videos(config_path, [video_path], shuffle=0, videotype="mp4", auto_track=True)
    logging.info("Video analysis completed")

    # Filter the predictions.
    logging.info(f"Filtering predictions for video: {video_path}")
    dlc.filterpredictions(config_path, [video_path], shuffle=0, videotype='mp4')
    logging.info("Predictions filtered")

    # Render the labeled video from the filtered predictions.
    logging.info(f"Creating labeled video for: {video_path}")
    dlc.create_labeled_video(
        config_path,
        [video_path],
        videotype='mp4',
        shuffle=0,
        color_by="individual",
        keypoints_only=False,
        draw_skeleton=True,
        filtered=True,
    )
    logging.info("Labeled video created")


@app.post("/post_download/{filename}")
async def post_download_file(filename: str):
    """Serve a file from the working directory and mark it as downloaded.

    Raises:
        HTTPException: 404 if *filename* is not a plain name of a regular
            file in the working directory.
    """
    file_path = os.path.join(WORKING_DIRECTORY, filename)
    # isfile (not exists) so directories such as ".." are never handed to FileResponse.
    if not _is_safe_filename(filename) or not os.path.isfile(file_path):
        raise HTTPException(status_code=404, detail="File not found")

    # Marked when the response is returned, i.e. possibly before the transfer completes.
    downloaded_files[filename] = True
    logging.info(f"Serving file: {file_path}")
    return FileResponse(path=file_path, media_type='application/octet-stream', filename=filename)


async def wait_for_files_to_be_downloaded(files: list):
    """Poll every 5 seconds until all *files* are marked downloaded, then delete them.

    A file counts as downloaded once its basename maps to True in
    ``downloaded_files``. Errors are logged, not raised.
    NOTE(review): nothing in this module schedules this coroutine, and it
    waits indefinitely if a file is never requested.
    """
    try:
        while not all(downloaded_files.get(os.path.basename(file), False) for file in files):
            logging.info(f"Waiting for files to be downloaded: {files}")
            await asyncio.sleep(5)

        logging.info("All files have been downloaded, deleting them...")
        for file in files:
            if os.path.exists(file):
                os.remove(file)
                logging.info(f"Deleted file: {file}")
        logging.info("All files have been deleted.")
    except Exception as e:
        logging.error(f"Failed to wait for files: {str(e)}")


def find_and_rename_output_files(
    video_path: str, working_directory: str = WORKING_DIRECTORY
) -> Tuple[Optional[str], Optional[str]]:
    """Locate the labeled video and filtered H5 output and rename them after the input video.

    ``*_id_labeled.mp4`` becomes ``<base>_labeled.mp4`` and ``*_filtered.h5``
    becomes ``<base>.h5``, where ``<base>`` is *video_path*'s file name
    without extension. If several files match, the last one listed wins.

    Returns:
        ``(labeled_video_path, h5_file_path)``; either entry is None if no
        matching file was found.
    """
    base_name = os.path.splitext(os.path.basename(video_path))[0]
    labeled_video = None
    h5_file = None

    for file in os.listdir(working_directory):
        source = os.path.join(working_directory, file)
        if file.endswith("_id_labeled.mp4"):
            labeled_video = os.path.join(working_directory, f"{base_name}_labeled.mp4")
            os.rename(source, labeled_video)
            logging.info(f"Renamed labeled video to: {labeled_video}")
        elif file.endswith("_filtered.h5"):
            h5_file = os.path.join(working_directory, f"{base_name}.h5")
            os.rename(source, h5_file)
            logging.info(f"Renamed H5 file to: {h5_file}")

    logging.info(f"Files in working directory after video processing: {os.listdir(working_directory)}")
    return labeled_video, h5_file


def _remove_partial_file(path: str) -> None:
    """Best-effort removal of a partially written download."""
    with contextlib.suppress(OSError):
        os.remove(path)


def download_video(url: str, video_oss_id: str) -> str:
    """Stream *url* to ``<WORKING_DIRECTORY>/<video_oss_id>.mp4`` and return that path.

    Raises:
        HTTPException: 400 if *video_oss_id* is not a plain file name; 500 on
            any network or file-system failure (any partial file is removed).
    """
    working_directory = WORKING_DIRECTORY

    # The id becomes a file name, so it must not contain directory components.
    if not _is_safe_filename(str(video_oss_id)):
        logging.error(f"Invalid videoOssId: {video_oss_id!r}")
        raise HTTPException(status_code=400, detail="videoOssId must be a plain file name")

    # Name the file after video_oss_id to avoid naming collisions.
    local_filename = os.path.join(working_directory, f"{video_oss_id}.mp4")
    try:
        os.makedirs(working_directory, exist_ok=True)

        logging.info(f"Downloading video from URL: {url}")
        with requests.get(url, stream=True, timeout=60) as r:
            r.raise_for_status()
            with open(local_filename, 'wb') as f:
                for chunk in r.iter_content(chunk_size=8192):
                    if chunk:  # skip empty keep-alive chunks
                        f.write(chunk)
        logging.info(f"Video downloaded to: {local_filename}")
        return local_filename
    except requests.exceptions.RequestException as e:
        _remove_partial_file(local_filename)
        logging.error(f"Failed to download video: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Failed to download video: {str(e)}") from e
    except Exception as e:
        _remove_partial_file(local_filename)
        logging.error(f"Unexpected error: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Unexpected error: {str(e)}") from e


def clear_working_directory(directory: str):
    """Delete every file, symlink and subdirectory inside *directory*, keeping the directory.

    Best effort: a missing directory counts as already clear, and a failure on
    one entry is logged without stopping removal of the others. Never raises.
    """
    try:
        entries = os.listdir(directory)
    except FileNotFoundError:
        logging.info(f"Directory does not exist, nothing to clear: {directory}")
        return
    except OSError as e:
        logging.error(f"Failed to clear directory {directory}: {str(e)}")
        return

    failures = 0
    for filename in entries:
        file_path = os.path.join(directory, filename)
        try:
            if os.path.isfile(file_path) or os.path.islink(file_path):
                os.unlink(file_path)
            elif os.path.isdir(file_path):
                # rmtree, unlike os.rmdir, also removes non-empty directories.
                shutil.rmtree(file_path)
        except OSError as e:
            failures += 1
            logging.error(f"Failed to remove {file_path}: {str(e)}")

    if failures:
        logging.error(f"Failed to clear directory {directory}: {failures} entries could not be removed")
    else:
        logging.info(f"Cleared all files in directory: {directory}")


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8080)