File size: 1,843 Bytes
3cecacc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
from database import SessionLocal
from db_models.music import Music, AI_Detection_Music
from db_models.analysis import Music_Analysis, AnalysisStatus
from db_models.user import User 
from google.cloud import storage
import uuid
import os
import torch
from inference import run_inference, load_audio, download_from_gcs
from inference import get_model
# Use the GPU when one is available. Fall back to CPU so that importing this
# module does not fail outright on CPU-only hosts (inference will be slower).
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Shared MERT backbone, loaded once at import time and reused by every
# do_web_inference call in this module.
backbone_model, input_dim = get_model('MERT', device)

def do_web_inference(model, music_id):
    """
    Run AI-generated-music detection on a stored music file.

    The audio referenced by ``Music.music_path`` is downloaded from the
    ``mippia-bucket`` GCS bucket to a local file at the same path. It is then
    loaded at 24 kHz via ``load_audio`` and embedded with the module-level
    MERT backbone. Finally, ``model`` scores it through ``run_inference``.
    The local copy is deleted afterwards, even if any step fails.

    Args:
        model: Detection model passed straight through to ``run_inference``.
        music_id: ID of the ``Music`` row to analyze.

    Returns:
        On success, whatever ``run_inference`` returns. If no ``Music`` row
        has the given ID, returns ``{"status": "error", "message": ...}``.
    """
    db = SessionLocal()
    try:
        music = db.query(Music).filter(Music.id == music_id).first()
        # Guard before touching any attribute of the row.
        if music is None:
            return {"status": "error", "message": f"Music ID {music_id} not found"}

        # The GCS blob name doubles as the local download path.
        wav_path = music.music_path
        try:
            download_from_gcs('mippia-bucket', wav_path, wav_path)
            segments, padding_mask = load_audio(wav_path, sr=24000)
            segments = segments.to(device=device, dtype=torch.float32)
            # Feature extraction only: skip autograd bookkeeping to save memory.
            with torch.no_grad():
                _logits, embedding = backbone_model(segments.squeeze(1))
            # Tensor.to() is not in-place, so keep the returned tensor.
            embedding = embedding.to(device)
            results = run_inference(model, embedding, padding_mask, device=device)
        finally:
            # Always remove the temporary download, even when a step above raises.
            if os.path.exists(wav_path):
                os.remove(wav_path)
        return results
    finally:
        # Close the database session on every exit path.
        db.close()