# aimusicdetection / web_inference.py
# nininigold's picture
# Upload folder using huggingface_hub
# 3cecacc verified
from database import SessionLocal
from db_models.music import Music, AI_Detection_Music
from db_models.analysis import Music_Analysis, AnalysisStatus
from db_models.user import User
from google.cloud import storage
import uuid
import os
import torch
from inference import run_inference, load_audio, download_from_gcs
from inference import get_model
# Device used for all tensors and models; assumes a CUDA GPU is
# available at import time — TODO confirm (no CPU fallback here).
device = 'cuda'
# Load the MERT backbone once at module import so every call to
# do_web_inference reuses it. input_dim is the embedding dimension
# reported by get_model; it is not used in this chunk of the file.
backbone_model, input_dim = get_model('MERT',device)
def do_web_inference(model, music_id):
    """
    Run AI-generation detection inference on a music file.

    Args:
        model: Classifier head applied to the backbone embeddings
            (passed through to ``run_inference``).
        music_id: ID of the music row to analyze.

    Returns:
        The inference results from ``run_inference``, or an error
        dictionary when the music row cannot be found.
    """
    db = None
    wav_path = None
    try:
        # Fetch the music record from the database.
        db = SessionLocal()
        music = db.query(Music).filter(Music.id == music_id).first()
        # Guard BEFORE dereferencing: the original accessed
        # music.ai_detection_musics[0] first, so a missing row raised
        # AttributeError instead of returning this error dict.
        if not music:
            return {"status": "error", "message": f"Music ID {music_id} not found"}
        # Renamed local: the original shadowed the imported
        # AI_Detection_Music model class with this variable.
        detection_record = music.ai_detection_musics[0]
        print(music, music_id)
        print(detection_record.id)

        # Download the audio file from GCS to a local path.
        # NOTE(review): the GCS object path is reused as the local path —
        # confirm this layout is intended and writable.
        wav_path = music.music_path
        download_from_gcs('mippia-bucket', wav_path, wav_path)

        # Load and preprocess audio; the backbone expects 24 kHz input.
        segments, padding_mask = load_audio(wav_path, sr=24000)
        segments = segments.to(device).to(torch.float32)
        logits, embedding = backbone_model(segments.squeeze(1))
        # Tensor.to() is not in-place; the original discarded its result,
        # leaving the embedding on whatever device the backbone produced.
        embedding = embedding.to(device)

        # Run inference with the classifier head.
        results = run_inference(model, embedding, padding_mask, device=device)
        print(results)
        # Return the results as promised by the docstring (the original
        # only printed them and implicitly returned None).
        return results
    finally:
        # Remove the downloaded temp file even when inference fails
        # (the original cleanup ran before finally and could leak it).
        if wav_path and os.path.exists(wav_path):
            os.remove(wav_path)
        # Close the database session if it was opened.
        if db is not None:
            db.close()