import os
from pathlib import Path
import json
from typing import Dict, List, Optional, Tuple

import numpy as np
import scipy.signal as signal
import torch
import torchaudio
import pytorch_lightning as pl

from beat_this.inference import File2Beats
from model import MusicClassifier, MusicAudioClassifier
from dataset_f import FakeMusicCapsDataset
from preprocess import get_segments_from_wav, find_optimal_segment_length
def highpass_filter(y, sr, cutoff=1000, order=5):
    """Apply an order-`order` Butterworth high-pass filter (cutoff in Hz) to waveform `y` sampled at `sr`."""
    if isinstance(sr, np.ndarray):
        sr = np.mean(sr)
    if not isinstance(sr, (int, float)):
        raise ValueError(f"sr must be a number, but got {type(sr)}: {sr}")
    nyquist = 0.5 * sr
    # Clamp the cutoff into a valid range below the Nyquist frequency
    if cutoff <= 0 or cutoff >= nyquist:
        cutoff = max(10, min(cutoff, nyquist - 1))
    normal_cutoff = cutoff / nyquist
    b, a = signal.butter(order, normal_cutoff, btype='high', analog=False)
    y_filtered = signal.lfilter(b, a, y)
    return y_filtered
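# Minimal usage sketch (illustrative only; "example.wav" and the 1 kHz cutoff are placeholders):
#
#   y, file_sr = torchaudio.load("example.wav")
#   y_hp = highpass_filter(y.squeeze().numpy(), file_sr, cutoff=1000)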
def load_audio(audio_path: str, sr: int = 24000) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Load an audio file and split it into segments.
    Extracts up to 48 fixed-length segments and zero-pads when fewer are available.

    Args:
        audio_path: Path to the audio file
        sr: Target sampling rate (default 24000)

    Returns:
        Tuple containing:
            - Audio waveform tensor of shape (48, 1, 240000)
            - Padding mask tensor of shape (48,); True = padding, False = real audio
    """
    beats, downbeats = get_segments_from_wav(audio_path)
    optimal_length, cleaned_downbeats = find_optimal_segment_length(downbeats)

    waveform, sample_rate = torchaudio.load(audio_path)
    # Convert to float32
    waveform = waveform.to(torch.float32)
    if sample_rate != sr:
        resampler = torchaudio.transforms.Resample(sample_rate, sr)
        waveform = resampler(waveform)

    # Downmix to mono if needed
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)

    # 240000 samples = 10 seconds @ 24 kHz
    fixed_samples = 240000

    # Build one segment starting at each downbeat
    segments = []

    # If there are no (or almost no) downbeats, split the whole audio at a fixed interval
    if len(cleaned_downbeats) < 2:
        # Total audio duration in seconds
        total_duration = waveform.size(1) / sr
        # Start a segment every 5 seconds; since a segment spans 10 s, consecutive segments overlap by 50%
        segment_interval = 5.0  # seconds
        # Segment start times, beginning at 0 s
        start_times = list(np.arange(0, total_duration - (fixed_samples / sr) + 0.01, segment_interval))
        # Guarantee at least one segment
        if not start_times and total_duration > 0:
            start_times = [0.0]
    else:
        # Default path: use the detected downbeats as segment starts
        start_times = cleaned_downbeats

    # Extract segments
    for i, start_time in enumerate(start_times):
        # Start sample index
        start_sample = int(start_time * sr)
        # End sample index (start + fixed length)
        end_sample = start_sample + fixed_samples

        # Check whether the segment would run past the end of the file
        if end_sample > waveform.size(1):
            # Short track: try taking the segment backwards from the end instead
            if start_sample < waveform.size(1) and waveform.size(1) >= fixed_samples:
                start_sample = waveform.size(1) - fixed_samples
                end_sample = waveform.size(1)
            else:
                continue

        # Extract a segment of exactly fixed_samples samples
        segment = waveform[:, start_sample:end_sample]
        # Apply the high-pass filter, keeping the channel dimension (cast to float32 so it stacks with the padding)
        filtered = torch.tensor(highpass_filter(segment.squeeze().numpy(), sr), dtype=torch.float32).unsqueeze(0)
        segments.append(filtered)

        # Use at most 48 segments
        if len(segments) >= 48:
            break

    # No segments: the track is too short to build a fixed-length segment
    if not segments:
        if waveform.size(1) > 0:  # There is audio, but it is very short
            # Use the whole audio as a single segment and zero-pad the remainder
            segment = waveform
            # Pad up to the required length
            padding_length = fixed_samples - segment.size(1)
            if padding_length > 0:
                segment = torch.nn.functional.pad(segment, (0, padding_length))
            # Apply the high-pass filter
            filtered = torch.tensor(highpass_filter(segment.squeeze().numpy(), sr), dtype=torch.float32).unsqueeze(0)
            segments.append(filtered)
        else:
            # Completely empty audio
            return torch.zeros((48, 1, fixed_samples), dtype=torch.float32), torch.ones(48, dtype=torch.bool)

    # Stack into a single tensor, keeping the (n_segments, 1, time_samples) shape
    stacked_segments = torch.stack(segments)

    # Number of real (non-padded) segments
    num_segments = stacked_segments.shape[0]

    # Padding mask (False = real audio, True = padding)
    padding_mask = torch.zeros(48, dtype=torch.bool)

    # Pad up to 48 segments if needed
    if num_segments < 48:
        # Pad with empty (zero) segments
        padding = torch.zeros((48 - num_segments, 1, fixed_samples), dtype=torch.float32)
        stacked_segments = torch.cat([stacked_segments, padding], dim=0)
        # Mark the padded slots (True = padding)
        padding_mask[num_segments:] = True

    return stacked_segments, padding_mask
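# Usage sketch (illustrative; the path is a placeholder):
#
#   segments, mask = load_audio("example_song.wav", sr=24000)
#   # segments: (48, 1, 240000) float32; mask: (48,) bool, True marks zero-padded slots
#   n_real = int((~mask).sum())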
def run_inference(model, audio_segments: torch.Tensor, padding_mask: torch.Tensor,
                  device: str = 'cuda' if torch.cuda.is_available() else 'cpu') -> Dict:
    """
    Run inference on audio segments.

    Args:
        model: The loaded model
        audio_segments: Preprocessed audio segments tensor (48, 1, 240000), or embedding segments
        padding_mask: Boolean mask of shape (48,) or (1, 48); True marks padded segments
        device: Device to run inference on

    Returns:
        Dictionary with prediction results
    """
    model.eval()
    model.to(device)
    model = model.half()
    print(padding_mask.shape)

    with torch.no_grad():
        # Check and adjust the input shape so it matches wav_collate_with_mask
        if audio_segments.shape[1] == 1:  # (48, 1, 240000) waveform segments
            # Drop the channel dimension and add a batch dimension
            audio_segments = audio_segments[:, 0, :].unsqueeze(0)  # (1, 48, 240000)
        else:
            # The input may already be embedding segments rather than raw audio
            audio_segments = audio_segments.unsqueeze(0)  # (1, 48, 768)

        if padding_mask.dim() == 1:
            padding_mask = padding_mask.unsqueeze(0)  # [48] -> [1, 48]

        # Move to the target device and convert to half precision
        audio_segments = audio_segments.to(device).half()
        mask = padding_mask.to(device)

        print(f"Input shape: {audio_segments.shape}")
        print(f"Mask shape: {mask.shape}")
        print(f"Mask: {mask}")

        # Run inference (with the padding mask)
        outputs = model(audio_segments, mask)
        print(f"Output type: {type(outputs)}")
        print(f"Output: {outputs}")

        # Handle the model's output structure
        if isinstance(outputs, dict):
            result = outputs
        else:
            # Single tensor output (logit)
            logits = outputs.squeeze()
            prob = torch.sigmoid(logits).item()
            result = {
                "prediction": "Fake" if prob > 0.5 else "Real",
                "confidence": f"{max(prob, 1 - prob) * 100:.2f}%",
                "fake_probability": f"{prob:.4f}",
                "real_probability": f"{1 - prob:.4f}",
                "raw_output": logits.cpu().numpy().tolist()
            }

    return result
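# End-to-end sketch (illustrative; assumes a classifier checkpoint compatible with
# MusicAudioClassifier and, when available, a CUDA device):
#
#   segments, mask = load_audio("example_song.wav", sr=24000)
#   result = run_inference(classifier, segments, mask)
#   print(result["prediction"], result["confidence"])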
def get_model(model_type, device):
    """Load the specified backbone model and return it together with its embedding dimension."""
    if model_type == "MERT":
        from ISMIR_2025.MERT.networks import CCV
        model = CCV(embed_dim=768, num_heads=8, num_layers=6, num_classes=2, freeze_feature_extractor=True).to(device)
        ckpt_file = 'mert_finetune_10.pth'
        model.load_state_dict(torch.load(ckpt_file, map_location=device))
        embed_dim = 768
    else:
        raise ValueError(f"Unknown model type: {model_type}")
    model.eval()
    return model, embed_dim
""" | |
elif model_type == "music2vec": | |
from ISMIR_2025.music2vec.networks import Music2VecClassifier | |
model = Music2VecClassifier(freeze_feature_extractor=True).to(device) | |
ckpt_file = '/data/kym/AI_Music_Detection/Code/model/music2vec/ckpt/fakemusicretrain/musiv2vec_processor/finetune_10.pth' | |
embed_dim = 768 | |
elif model_type == "wav2vec": | |
from ISMIR_2025.wav2vec.networks import Wav2Vec2ForFakeMusic | |
model = Wav2Vec2ForFakeMusic(num_classes=2, freeze_feature_extractor=True).to(device) | |
ckpt_file = '/data/kym/AI_Music_Detection/Code/model/wav2vec/ckpt/split/wav2vec_processor/wav2vec2_finetune_10.pth' | |
embed_dim = 768 | |
elif model_type == "ccv": | |
from ISMIR_2025.Model.networks import CCV | |
model = CCV(embed_dim=512, num_heads=8, num_layers=6, num_classes=2, freeze_feature_extractor=True).to(device) | |
ckpt_file = '/data/kym/AI_Music_Detection/Code/model/ckpt/datasplit/hp1000/best_model_CCV.pth' | |
embed_dim = 512 | |
""" | |
def inference_with_audio(audio_path):
    # audio_path = "The Chainsmokers & Coldplay - Something Just Like This (Lyric).mp3"
    model_type = "MERT"
    checkpoint_path = "with_embedding_MERT_768_embedding/EmbeddingModel_MERT_768-epoch=0353-val_loss=0.3866-val_acc=0.9809-val_f1=0.9803-val_precision=0.9764-val_recall=0.9842.ckpt"
    device = 'cuda'

    # Note: model loading would be handled by your code
    print(f"Loading model of type {model_type} from {checkpoint_path}")
    backbone_model, input_dim = get_model(model_type, device)

    segments, padding_mask = load_audio(audio_path, sr=24000)
    segments = segments.to(device).to(torch.float32)
    padding_mask = padding_mask.to(device).unsqueeze(0)
    logits, embedding = backbone_model(segments.squeeze(1))

    test_dataset = FakeMusicCapsDataset([audio_path], [0], target_duration=10.0)
    test_data, test_target = test_dataset[0]
    test_data = test_data.to(device).to(torch.float32)
    test_target = test_target.to(device)
    output, _ = backbone_model(test_data.unsqueeze(0))

    # Load the segment-level classifier from the checkpoint
    model = MusicAudioClassifier.load_from_checkpoint(
        checkpoint_path,
        input_dim=input_dim,
        # emb_model=backbone_model
        is_emb=True,
        mode='both'
    )

    # Run inference
    print(f"Segments shape: {segments.shape}")
    print("Running inference...")
    results = run_inference(model, embedding, padding_mask, device=device)

    # Print the results
    print(f"Results: {results}")
    return str(results)
if __name__ == "__main__": | |
inference_with_audio() | |