# aimusicdetection/inference.py
import os
import argparse
from typing import Dict, Tuple

import numpy as np
import scipy.signal as signal
import torch
import torchaudio
from google.cloud import storage

from model import MusicAudioClassifier
from dataset_f import FakeMusicCapsDataset
from preprocess import get_segments_from_wav, find_optimal_segment_length


# Not used for the ISMIR submission.
def download_from_gcs(bucket_name, source_blob_name, destination_file_name):
    """Download a single blob from a GCS bucket to a local path."""
    destination_dir = os.path.dirname(destination_file_name)
    if destination_dir and not os.path.exists(destination_dir):
        os.makedirs(destination_dir)
    storage_client = storage.Client()
    bucket = storage_client.bucket(bucket_name)
    blob = bucket.blob(source_blob_name)
    blob.download_to_filename(destination_file_name)
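
# Usage sketch for the GCS helper; the bucket and object names below are
# hypothetical placeholders, not paths from this project:
#   download_from_gcs("my-bucket", "ckpts/mert_finetune_10.pth",
#                     "./ckpts/mert_finetune_10.pth")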
def highpass_filter(y, sr, cutoff=1000, order=5):
    """Apply a Butterworth high-pass filter to a 1-D waveform."""
    if isinstance(sr, np.ndarray):
        sr = np.mean(sr)
    if not isinstance(sr, (int, float)):
        raise ValueError(f"sr must be a number, but got {type(sr)}: {sr}")
    nyquist = 0.5 * sr
    # Clamp the cutoff into the valid (0, Nyquist) range.
    if cutoff <= 0 or cutoff >= nyquist:
        cutoff = max(10, min(cutoff, nyquist - 1))
    normal_cutoff = cutoff / nyquist
    b, a = signal.butter(order, normal_cutoff, btype='high', analog=False)
    y_filtered = signal.lfilter(b, a, y)
    return y_filtered
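
# Minimal sketch of what the filter does, assuming a synthetic two-tone signal
# (values are illustrative only):
#   sr = 24000
#   t = np.arange(sr) / sr
#   mix = np.sin(2 * np.pi * 100 * t) + np.sin(2 * np.pi * 4000 * t)
#   hp = highpass_filter(mix, sr, cutoff=1000)  # the 100 Hz tone is attenuated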
def load_audio(audio_path: str, sr: int = 24000) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Load an audio file and split it into segments.

    Extracts up to 48 fixed-length segments and pads with zeros when fewer
    are available.

    Args:
        audio_path: Path to the audio file.
        sr: Target sampling rate (default 24000).

    Returns:
        Tuple containing:
            - Tensor of audio waveforms, shape (48, 1, 240000).
            - Padding-mask tensor of shape (48,); True = padding, False = real audio.
    """
    beats, downbeats = get_segments_from_wav(audio_path)
    optimal_length, cleaned_downbeats = find_optimal_segment_length(downbeats)

    waveform, sample_rate = torchaudio.load(audio_path)
    # Cast to float32.
    waveform = waveform.to(torch.float32)
    if sample_rate != sr:
        resampler = torchaudio.transforms.Resample(sample_rate, sr)
        waveform = resampler(waveform)

    # Convert to mono if needed.
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)

    # 240000 samples = 10 s @ 24 kHz.
    fixed_samples = 240000

    # Build one segment starting at each downbeat.
    segments = []

    # With no (or very few) downbeats, slice the whole file at a fixed stride instead.
    if len(cleaned_downbeats) < 2:
        # Total audio duration in seconds.
        total_duration = waveform.size(1) / sr
        # Segment start points every 5 s (a shorter stride is also possible).
        # Each segment is 10 s, so a 5 s stride gives 50% overlap between segments.
        segment_interval = 5.0  # seconds
        # Start times beginning at 0 s.
        start_times = [t for t in np.arange(0, total_duration - (fixed_samples / sr) + 0.01, segment_interval)]
        # Guarantee at least one segment.
        if not start_times and total_duration > 0:
            start_times = [0.0]
    else:
        # Use the downbeats as before.
        start_times = cleaned_downbeats

    # Extract the segments.
    for i, start_time in enumerate(start_times):
        # Start index in samples.
        start_sample = int(start_time * sr)
        # End index = start + fixed length.
        end_sample = start_sample + fixed_samples

        # Check whether the segment would run past the end of the file.
        if end_sample > waveform.size(1):
            # Short track: try taking the segment backwards from the end.
            if start_sample < waveform.size(1) and waveform.size(1) >= fixed_samples:
                start_sample = waveform.size(1) - fixed_samples
                end_sample = waveform.size(1)
            else:
                continue

        # Slice exactly fixed_samples samples.
        segment = waveform[:, start_sample:end_sample]
        # Apply the high-pass filter, keeping the channel dimension.
        filtered = torch.tensor(highpass_filter(segment.squeeze().numpy(), sr)).unsqueeze(0)
        segments.append(filtered)

        # Keep at most 48 segments.
        if len(segments) >= 48:
            break

    # No segments: the track was too short to yield a fixed-length segment.
    if not segments:
        if waveform.size(1) > 0:  # Audio exists but is very short.
            # Use the whole file as one segment and zero-pad the remainder.
            segment = waveform
            # Pad up to the required length.
            padding_length = fixed_samples - segment.size(1)
            if padding_length > 0:
                segment = torch.nn.functional.pad(segment, (0, padding_length))
            # Apply the high-pass filter.
            filtered = torch.tensor(highpass_filter(segment.squeeze().numpy(), sr)).unsqueeze(0)
            segments.append(filtered)
        else:
            # Completely empty audio.
            return torch.zeros((48, 1, fixed_samples), dtype=torch.float32), torch.ones(48, dtype=torch.bool)

    # Stack into a tensor, keeping the (n_segments, 1, time_samples) shape.
    stacked_segments = torch.stack(segments)
    # Number of real (non-padding) segments.
    num_segments = stacked_segments.shape[0]

    # Padding mask (False = real audio, True = padding).
    padding_mask = torch.zeros(48, dtype=torch.bool)

    # Pad up to 48 segments when fewer were extracted.
    if num_segments < 48:
        # Pad with empty (all-zero) segments.
        padding = torch.zeros((48 - num_segments, 1, fixed_samples), dtype=torch.float32)
        stacked_segments = torch.cat([stacked_segments, padding], dim=0)
        # Mark the padded entries (True = padding).
        padding_mask[num_segments:] = True

    return stacked_segments, padding_mask
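
# Usage sketch ("track.wav" is a hypothetical file):
#   segs, mask = load_audio("track.wav", sr=24000)
#   assert segs.shape == (48, 1, 240000) and mask.shape == (48,)
#   n_real = int((~mask).sum())  # number of non-padded segments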
def run_inference(model, audio_segments: torch.Tensor, padding_mask: torch.Tensor, device: str = 'cuda' if torch.cuda.is_available() else 'cpu') -> Dict:
    """
    Run inference on audio segments.

    Args:
        model: The loaded model.
        audio_segments: Preprocessed audio segments tensor (48, 1, 240000),
            or precomputed embedding segments of shape (48, embed_dim).
        padding_mask: Boolean mask of shape (48,); True marks padded segments.
        device: Device to run inference on.

    Returns:
        Dictionary with prediction results.
    """
    model.eval()
    model.to(device)
    model = model.half()
    print(padding_mask.shape)

    with torch.no_grad():
        # Check and reshape the input to match what wav_collate_with_mask produces.
        if audio_segments.shape[1] == 1:  # raw audio, shape (48, 1, 240000)
            # Drop the channel dimension and add a batch dimension.
            audio_segments = audio_segments[:, 0, :].unsqueeze(0)  # (1, 48, 240000)
        else:
            # These may be embedding segments rather than raw audio.
            audio_segments = audio_segments.unsqueeze(0)  # (1, 48, 768)

        if padding_mask.dim() == 1:
            padding_mask = padding_mask.unsqueeze(0)  # [48] -> [1, 48]

        # Cast the input to half precision to match the model.
        audio_segments = audio_segments.to(device).half()
        mask = padding_mask.to(device)

        print(f"Input shape: {audio_segments.shape}")
        print(f"Mask shape: {mask.shape}")
        print(f"Mask: {mask}")

        # Run the forward pass (with the mask).
        outputs = model(audio_segments, mask)
        print(f"Output type: {type(outputs)}")
        print(f"Output: {outputs}")

        # Handle the model's output structure.
        if isinstance(outputs, dict):
            result = outputs
        else:
            # Single tensor output (a logit).
            logits = outputs.squeeze()
            prob = torch.sigmoid(logits).item()
            result = {
                "prediction": "Fake" if prob > 0.5 else "Real",
                "confidence": f"{max(prob, 1 - prob) * 100:.2f}%",
                "fake_probability": f"{prob:.4f}",
                "real_probability": f"{1 - prob:.4f}",
                "raw_output": logits.cpu().numpy().tolist()
            }

    return result
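
# Usage sketch tying the two helpers together (assumes a classifier loaded
# elsewhere, e.g. via MusicAudioClassifier.load_from_checkpoint; "track.wav"
# is a hypothetical file):
#   segs, mask = load_audio("track.wav")
#   result = run_inference(classifier, segs, mask)
#   print(result["prediction"], result["confidence"])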
def get_model(model_type, device):
    """Load the specified backbone model."""
    if model_type == "MERT":
        from ISMIR_2025.MERT.networks import CCV
        model = CCV(embed_dim=768, num_heads=8, num_layers=6, num_classes=2, freeze_feature_extractor=True).to(device)
        ckpt_file = 'mert_finetune_10.pth'
        model.load_state_dict(torch.load(ckpt_file, map_location=device))
        embed_dim = 768
    else:
        raise ValueError(f"Unknown model type: {model_type}")
    model.eval()
    return model, embed_dim
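
# Usage sketch (requires 'mert_finetune_10.pth' to be present locally):
#   backbone, embed_dim = get_model("MERT", "cuda")  # embed_dim == 768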
"""
elif model_type == "music2vec":
from ISMIR_2025.music2vec.networks import Music2VecClassifier
model = Music2VecClassifier(freeze_feature_extractor=True).to(device)
ckpt_file = '/data/kym/AI_Music_Detection/Code/model/music2vec/ckpt/fakemusicretrain/musiv2vec_processor/finetune_10.pth'
embed_dim = 768
elif model_type == "wav2vec":
from ISMIR_2025.wav2vec.networks import Wav2Vec2ForFakeMusic
model = Wav2Vec2ForFakeMusic(num_classes=2, freeze_feature_extractor=True).to(device)
ckpt_file = '/data/kym/AI_Music_Detection/Code/model/wav2vec/ckpt/split/wav2vec_processor/wav2vec2_finetune_10.pth'
embed_dim = 768
elif model_type == "ccv":
from ISMIR_2025.Model.networks import CCV
model = CCV(embed_dim=512, num_heads=8, num_layers=6, num_classes=2, freeze_feature_extractor=True).to(device)
ckpt_file = '/data/kym/AI_Music_Detection/Code/model/ckpt/datasplit/hp1000/best_model_CCV.pth'
embed_dim = 512
"""
def inference_with_audio(audio_path):
    # audio_path = "The Chainsmokers & Coldplay - Something Just Like This (Lyric).mp3"
    model_type = "MERT"
    checkpoint_path = "with_embedding_MERT_768_embedding/EmbeddingModel_MERT_768-epoch=0353-val_loss=0.3866-val_acc=0.9809-val_f1=0.9803-val_precision=0.9764-val_recall=0.9842.ckpt"
    device = 'cuda'

    # Note: Model loading would be handled by your code
    print(f"Loading model of type {model_type} from {checkpoint_path}")
    backbone_model, input_dim = get_model(model_type, device)

    segments, padding_mask = load_audio(audio_path, sr=24000)
    segments = segments.to(device).to(torch.float32)
    padding_mask = padding_mask.to(device).unsqueeze(0)
    # Per-segment embeddings from the backbone (the logits are unused here).
    logits, embedding = backbone_model(segments.squeeze(1))

    test_dataset = FakeMusicCapsDataset([audio_path], [0], target_duration=10.0)
    test_data, test_target = test_dataset[0]
    test_data = test_data.to(device).to(torch.float32)
    test_target = test_target.to(device)
    output, _ = backbone_model(test_data.unsqueeze(0))

    # Load the classifier head from the Lightning checkpoint.
    model = MusicAudioClassifier.load_from_checkpoint(
        checkpoint_path,
        input_dim=input_dim,
        # emb_model=backbone_model
        is_emb=True,
        mode='both'
    )

    # Run inference on the precomputed embeddings.
    print(f"Segments shape: {segments.shape}")
    print("Running inference...")
    results = run_inference(model, embedding, padding_mask, device=device)

    # Print the results.
    print(f"Results: {results}")
    return str(results)
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Detect AI-generated music in a single audio file.")
    parser.add_argument("audio_path", help="Path to the audio file to classify")
    args = parser.parse_args()
    inference_with_audio(args.audio_path)