import logging
import asyncio
import os
import shutil
import threading
from typing import Dict

import requests
import deeplabcut as dlc
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import FileResponse
from fastapi.middleware.cors import CORSMiddleware

app = FastAPI()
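
# Allow cross-origin requests from any origin so a web frontend can call this API directly.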
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

project_path = "/app/kunin-dlc-240814"
config_path = os.path.join(project_path, "config.yaml")

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Serialize requests: only one DeepLabCut analysis runs at a time.
lock = threading.Lock()
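
# Maps output filename -> whether it has been served via /post_download yet.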
downloaded_files = {}
@app.post("/analyze") |
|
async def analyze_video(request: Request): |
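    """Download the requested video, run the DeepLabCut analysis pipeline on it,
    and return /post_download paths for the labeled video and the filtered H5 file."""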
    logging.info("Received request for video analysis")

    with lock:
        try:
            clear_working_directory("/app/working")

            data: Dict = await request.json()
            logging.info(f"Request data: {data}")
        except Exception as e:
            logging.error(f"Failed to parse JSON: {str(e)}")
            raise HTTPException(status_code=400, detail=f"Invalid JSON: {str(e)}")

        if 'videoUrl' not in data or 'videoOssId' not in data:
            logging.error("Missing videoUrl or videoOssId in request")
            raise HTTPException(status_code=400, detail="videoUrl and videoOssId are required")

        video_url = data['videoUrl']
        video_oss_id = data['videoOssId']

        try:
            video_path = download_video(video_url, video_oss_id)
            logging.info(f"Downloaded video to: {video_path}")
        except Exception as e:
            logging.error(f"Error downloading video: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Error downloading video: {str(e)}")

        try:
            logging.info(f"Starting analysis for video: {video_path}")
            dlc.analyze_videos(config_path, [video_path], shuffle=0, videotype="mp4", auto_track=True)
            logging.info("Video analysis completed")

            logging.info(f"Filtering predictions for video: {video_path}")
            dlc.filterpredictions(config_path, [video_path], shuffle=0, videotype='mp4')
            logging.info("Predictions filtered")

            logging.info(f"Creating labeled video for: {video_path}")
            dlc.create_labeled_video(
                config_path,
                [video_path],
                videotype='mp4',
                shuffle=0,
                color_by="individual",
                keypoints_only=False,
                draw_skeleton=True,
                filtered=True,
            )
            logging.info("Labeled video created")

            labeled_video_path, h5_file_path = find_and_rename_output_files(video_path)
            if not labeled_video_path or not h5_file_path:
                logging.error("Output files missing after analysis")
                raise HTTPException(status_code=500, detail="Analysis completed, but output files are missing.")

            # Register the outputs so /post_download can mark them as served.
            downloaded_files[os.path.basename(labeled_video_path)] = False
            downloaded_files[os.path.basename(h5_file_path)] = False

        except HTTPException:
            # Preserve HTTP errors raised above instead of re-wrapping them as generic 500s.
            raise
        except Exception as e:
            logging.error(f"Error during video analysis: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Error during video analysis: {str(e)}")

        response_data = {
            "videoOssId": video_oss_id,
            "labeled_video": f"/post_download/{os.path.basename(labeled_video_path)}",
            "h5_file": f"/post_download/{os.path.basename(h5_file_path)}"
        }

        logging.info("Returning response data")
        return response_data


@app.post("/post_download/{filename}")
async def post_download_file(filename: str):
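    """Serve a generated output file from the working directory and mark it as downloaded."""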
    file_path = os.path.join("/app/working", filename)
    if os.path.exists(file_path):
        # Record that this file has been fetched so it can be cleaned up later.
        downloaded_files[filename] = True
        logging.info(f"Serving file: {file_path}")
        return FileResponse(path=file_path, media_type='application/octet-stream', filename=filename)
    else:
        raise HTTPException(status_code=404, detail="File not found")


async def wait_for_files_to_be_downloaded(files: list):
    """Wait until the given files have been downloaded, then delete them."""
    try:
        while any(not downloaded_files.get(os.path.basename(file), False) for file in files):
            logging.info(f"Waiting for files to be downloaded: {files}")
            await asyncio.sleep(5)
        logging.info("All files have been downloaded, deleting them...")
        for file in files:
            if os.path.exists(file):
                os.remove(file)
                logging.info(f"Deleted file: {file}")
        logging.info("All files have been deleted.")
    except Exception as e:
        logging.error(f"Failed to wait for files: {str(e)}")


def find_and_rename_output_files(video_path: str):
    """Find the labeled video and H5 file produced by DeepLabCut and rename them after the source video."""
    working_directory = "/app/working"
    base_name = os.path.splitext(os.path.basename(video_path))[0]

    labeled_video = None
    h5_file = None

    for file in os.listdir(working_directory):
        if file.endswith("_id_labeled.mp4"):
            labeled_video = os.path.join(working_directory, file)
            new_labeled_video = os.path.join(working_directory, f"{base_name}_labeled.mp4")
            os.rename(labeled_video, new_labeled_video)
            labeled_video = new_labeled_video
            logging.info(f"Renamed labeled video to: {labeled_video}")
        elif file.endswith("_filtered.h5"):
            h5_file = os.path.join(working_directory, file)
            new_h5_file = os.path.join(working_directory, f"{base_name}.h5")
            os.rename(h5_file, new_h5_file)
            h5_file = new_h5_file
            logging.info(f"Renamed H5 file to: {h5_file}")

    logging.info(f"Files in working directory after video processing: {os.listdir(working_directory)}")
    return labeled_video, h5_file


def download_video(url: str, video_oss_id: str) -> str:
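    """Stream the video at `url` into the working directory as <video_oss_id>.mp4 and return the local path."""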
    working_directory = "/app/working"

    try:
        if not os.path.exists(working_directory):
            os.makedirs(working_directory)

        local_filename = os.path.join(working_directory, f"{video_oss_id}.mp4")

        logging.info(f"Downloading video from URL: {url}")
        with requests.get(url, stream=True, timeout=60) as r:
            r.raise_for_status()
            with open(local_filename, 'wb') as f:
                for chunk in r.iter_content(chunk_size=8192):
                    if chunk:
                        f.write(chunk)
        logging.info(f"Video downloaded to: {local_filename}")
        return local_filename

    except requests.exceptions.RequestException as e:
        logging.error(f"Failed to download video: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Failed to download video: {str(e)}")
    except Exception as e:
        logging.error(f"Unexpected error: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Unexpected error: {str(e)}")


def clear_working_directory(directory: str):
    """Remove all files and subdirectories from the working directory."""
    try:
        for filename in os.listdir(directory):
            file_path = os.path.join(directory, filename)
            if os.path.isfile(file_path) or os.path.islink(file_path):
                os.unlink(file_path)
            elif os.path.isdir(file_path):
                # os.rmdir only removes empty directories; rmtree also clears nested output folders.
                shutil.rmtree(file_path)
        logging.info(f"Cleared all files in directory: {directory}")
    except Exception as e:
        logging.error(f"Failed to clear directory {directory}: {str(e)}")


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8080)
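
# Example request (hypothetical URL and id; assumes the service is reachable on localhost:8080):
#   curl -X POST http://localhost:8080/analyze \
#        -H "Content-Type: application/json" \
#        -d '{"videoUrl": "https://example.com/video.mp4", "videoOssId": "example-id"}'
# The response contains /post_download/ paths for the labeled video and the filtered H5 file.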