from database import SessionLocal
from db_models.music import Music, AI_Detection_Music
from db_models.analysis import Music_Analysis, AnalysisStatus
from db_models.user import User
from google.cloud import storage
import uuid
import os
import torch
from inference import run_inference, load_audio, download_from_gcs
from inference import get_model
device = 'cuda'
backbone_model, input_dim = get_model('MERT', device)

def do_web_inference(model, music_id):
    """
    Runs AI-generation detection inference on an audio file.

    Args:
        model: the detection model passed to run_inference
        music_id: ID of the music track to analyze

    Returns:
        A dictionary containing the analysis results
    """
    try:
        # Fetch the music record from the database
        db = SessionLocal()
        music = db.query(Music).filter(Music.id == music_id).first()
        if not music:
            return {"status": "error", "message": f"Music ID {music_id} not found"}
        # Take the first AI-detection record; use a local name so the
        # imported AI_Detection_Music model class is not shadowed
        detection_record = music.ai_detection_musics[0]
        print(music, music_id)
        print(detection_record.id)
        # Get the file path and download the audio from GCS
        wav_path = music.music_path
        download_from_gcs('mippia-bucket', wav_path, wav_path)
        segments, padding_mask = load_audio(wav_path, sr=24000)
        segments = segments.to(device).to(torch.float32)
        logits, embedding = backbone_model(segments.squeeze(1))
        # Tensor.to() is not in-place; reassign to actually move the tensor
        embedding = embedding.to(device)
        # Run inference
        results = run_inference(model, embedding, padding_mask, device=device)

        # Delete the temporary audio file
        if os.path.exists(wav_path):
            os.remove(wav_path)
        print(results)
        return results
    finally:
        # Close the database session
        if 'db' in locals():
            db.close()
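
# Usage sketch (hypothetical, not part of the original file): assumes a
# trained detection head saved as 'detector_head.pt' and an existing Music
# row whose music_path points at a WAV object in the 'mippia-bucket' GCS
# bucket.
#
#     detector = torch.load('detector_head.pt', map_location=device)
#     results = do_web_inference(detector, music_id=1)
#     print(results)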