Spaces:
Runtime error
Runtime error
from database import SessionLocal | |
from db_models.music import Music, AI_Detection_Music | |
from db_models.analysis import Music_Analysis, AnalysisStatus | |
from db_models.user import User | |
from google.cloud import storage | |
import uuid | |
import os | |
import torch | |
from inference import run_inference, load_audio, download_from_gcs | |
from inference import get_model | |
device = 'cuda' | |
backbone_model, input_dim = get_model('MERT',device) | |
def do_web_inference(model, music_id): | |
""" | |
μμ νμΌμ λν΄ AI μμ± νμ§ μΆλ‘ μ μ€νν©λλ€. | |
Args: | |
music_id: λΆμν μμ μ ID | |
task: μνν μμ μ ν (κΈ°λ³Έκ°: None) | |
Returns: | |
λΆμ κ²°κ³Όλ₯Ό ν¬ν¨νλ λμ λ리 | |
""" | |
try: | |
# λ°μ΄ν°λ² μ΄μ€μμ μμ μ 보 κ°μ Έμ€κΈ° | |
db = SessionLocal() | |
music = db.query(Music).filter(Music.id == music_id).first() | |
AI_Detection_Music = music.ai_detection_musics[0] | |
print(music, music_id) | |
print(AI_Detection_Music.id) | |
if not music: | |
return {"status": "error", "message": f"Music ID {music_id} not found"} | |
# νμΌ κ²½λ‘ κ°μ Έμ€κΈ° | |
wav_path = music.music_path | |
download_from_gcs('mippia-bucket', wav_path, wav_path) | |
segments, padding_mask = load_audio(wav_path, sr=24000) | |
segments = segments.to(device).to(torch.float32) | |
logits,embedding = backbone_model(segments.squeeze(1)) | |
embedding.to(device) | |
# μΆλ‘ μ€ν | |
results = run_inference(model, embedding, padding_mask, device=device) | |
# μμ νμΌ μμ | |
if os.path.exists(wav_path): | |
os.remove(wav_path) | |
print(results) | |
finally: | |
# λ°μ΄ν°λ² μ΄μ€ μΈμ μ’ λ£ | |
if 'db' in locals(): | |
db.close() | |