# aimusicdetection/inference.py
import os
import json
from pathlib import Path
from typing import Dict, List, Tuple, Optional

import numpy as np
import scipy.signal as signal
import torch
import torchaudio
import pytorch_lightning as pl
from beat_this.inference import File2Beats

from model import MusicClassifier, MusicAudioClassifier
from dataset_f import FakeMusicCapsDataset
from preprocess import get_segments_from_wav, find_optimal_segment_length
def highpass_filter(y, sr, cutoff=1000, order=5):
    """Apply a Butterworth high-pass filter to a 1-D signal."""
    if isinstance(sr, np.ndarray):
        sr = np.mean(sr)
    if not isinstance(sr, (int, float)):
        raise ValueError(f"sr must be a number, but got {type(sr)}: {sr}")

    # Clamp the cutoff frequency into the valid (0, Nyquist) range
    nyquist = 0.5 * sr
    if cutoff <= 0 or cutoff >= nyquist:
        cutoff = max(10, min(cutoff, nyquist - 1))
    normal_cutoff = cutoff / nyquist

    b, a = signal.butter(order, normal_cutoff, btype='high', analog=False)
    y_filtered = signal.lfilter(b, a, y)
    return y_filtered
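
# Illustrative only (not part of the original pipeline): a tiny smoke test for
# highpass_filter on a synthetic 1-second, 24 kHz mono signal. It is defined but
# never called by this script.
def _example_highpass_filter():
    y = np.random.randn(24000).astype(np.float32)      # synthetic audio, 1 s @ 24 kHz
    y_hp = highpass_filter(y, sr=24000, cutoff=1000)   # np.ndarray of shape (24000,)
    return y_hp.shape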
def load_audio(audio_path: str, sr: int = 24000) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Load an audio file and split it into segments.
    Extracts up to 48 fixed-length segments and zero-pads when fewer are available.
    Args:
        audio_path: Path to the audio file
        sr: Target sampling rate (default 24000)
    Returns:
        Tuple containing:
        - Tensor of audio waveforms, shape (48, 1, 240000)
        - Padding mask tensor, shape (48); True = padding, False = real audio
    """
    beats, downbeats = get_segments_from_wav(audio_path)
    optimal_length, cleaned_downbeats = find_optimal_segment_length(downbeats)
    waveform, sample_rate = torchaudio.load(audio_path)

    # Convert to float32
    waveform = waveform.to(torch.float32)
    if sample_rate != sr:
        resampler = torchaudio.transforms.Resample(sample_rate, sr)
        waveform = resampler(waveform)

    # Convert to mono (if needed)
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)

    # 240000 samples = 10 seconds @ 24 kHz
    fixed_samples = 240000

    # Build one segment starting at each downbeat
    segments = []
    # If there are no (or very few) downbeats, split the whole audio at a regular interval
    if len(cleaned_downbeats) < 2:
        # Total audio duration in seconds
        total_duration = waveform.size(1) / sr
        # Generate segment start points every 5 seconds (a shorter interval would also work);
        # a segment is 240000 samples = 10 seconds, so a 5-second stride gives 50% overlap
        segment_interval = 5.0  # seconds
        # Start times, beginning at 0 s
        start_times = [t for t in np.arange(0, total_duration - (fixed_samples / sr) + 0.01, segment_interval)]
        # Guarantee at least one segment
        if not start_times and total_duration > 0:
            start_times = [0.0]
    else:
        # Use the downbeats as before
        start_times = cleaned_downbeats
    # Extract the segments
    for i, start_time in enumerate(start_times):
        # Start sample index
        start_sample = int(start_time * sr)
        # End sample index (start + fixed length)
        end_sample = start_sample + fixed_samples

        # Check whether the segment would run past the end of the file
        if end_sample > waveform.size(1):
            # Short track: try extracting the segment backwards from the end
            if start_sample < waveform.size(1) and waveform.size(1) >= fixed_samples:
                start_sample = waveform.size(1) - fixed_samples
                end_sample = waveform.size(1)
            else:
                continue

        # Extract a segment of exactly fixed_samples
        segment = waveform[:, start_sample:end_sample]
        # Apply the high-pass filter, keeping the channel dimension
        # (cast back to float32, since scipy's lfilter returns float64)
        filtered = torch.tensor(highpass_filter(segment.squeeze().numpy(), sr), dtype=torch.float32).unsqueeze(0)
        segments.append(filtered)

        # Use at most 48 segments
        if len(segments) >= 48:
            break
    # No segments: the track is too short to build a fixed-length segment
    if not segments:
        if waveform.size(1) > 0:  # audio exists but is very short
            # Use the whole audio as a single segment and zero-pad the rest
            segment = waveform
            # Pad up to the required length
            padding_length = fixed_samples - segment.size(1)
            if padding_length > 0:
                segment = torch.nn.functional.pad(segment, (0, padding_length))
            # Apply the high-pass filter (cast back to float32, as above)
            filtered = torch.tensor(highpass_filter(segment.squeeze().numpy(), sr), dtype=torch.float32).unsqueeze(0)
            segments.append(filtered)
        else:
            # Completely empty audio
            return torch.zeros((48, 1, fixed_samples), dtype=torch.float32), torch.ones(48, dtype=torch.bool)
# ์Šคํƒํ•˜์—ฌ ํ…์„œ๋กœ ๋ณ€ํ™˜ - (n_segments, 1, time_samples) ํ˜•ํƒœ ์œ ์ง€
stacked_segments = torch.stack(segments)
# ์‹ค์ œ ์„ธ๊ทธ๋จผํŠธ ์ˆ˜ (ํŒจ๋”ฉ ์•„๋‹˜)
num_segments = stacked_segments.shape[0]
# ํŒจ๋”ฉ ๋งˆ์Šคํฌ ์ƒ์„ฑ (False = ์‹ค์ œ ์˜ค๋””์˜ค, True = ํŒจ๋”ฉ)
padding_mask = torch.zeros(48, dtype=torch.bool)
# 48๊ฐœ ๋ฏธ๋งŒ์ธ ๊ฒฝ์šฐ ํŒจ๋”ฉ ์ถ”๊ฐ€
if num_segments < 48:
# ๋นˆ ์„ธ๊ทธ๋จผํŠธ๋กœ ํŒจ๋”ฉ (zeros)
padding = torch.zeros((48 - num_segments, 1, fixed_samples), dtype=torch.float32)
stacked_segments = torch.cat([stacked_segments, padding], dim=0)
# ํŒจ๋”ฉ ๋งˆ์Šคํฌ ์„ค์ • (True = ํŒจ๋”ฉ)
padding_mask[num_segments:] = True
return stacked_segments, padding_mask
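
# Illustrative only: the shapes returned by load_audio for an arbitrary clip.
# "example.wav" is a placeholder path, not a file shipped with this repository,
# and the function is defined but never called by this script.
def _example_load_audio():
    segments, mask = load_audio("example.wav", sr=24000)
    assert segments.shape == (48, 1, 240000)   # 48 segments of 10 s @ 24 kHz
    assert mask.shape == (48,)
    num_real = int((~mask).sum())              # segments that are real audio, not padding
    return num_real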
def run_inference(model, audio_segments: torch.Tensor, padding_mask: torch.Tensor, device: str = 'cuda' if torch.cuda.is_available() else 'cpu') -> Dict:
    """
    Run inference on audio segments.
    Args:
        model: The loaded model
        audio_segments: Preprocessed audio segments tensor (48, 1, 240000)
        padding_mask: Boolean mask of shape (48); True marks padded segments
        device: Device to run inference on
    Returns:
        Dictionary with prediction results
    """
    model.eval()
    model.to(device)
    model = model.half()
    print(padding_mask.shape)

    with torch.no_grad():
        # Check and adjust the input shape so it matches wav_collate_with_mask
        if audio_segments.shape[1] == 1:  # (48, 1, 240000)
            # Drop the channel dimension and add a batch dimension
            audio_segments = audio_segments[:, 0, :].unsqueeze(0)  # (1, 48, 240000)
        else:
            audio_segments = audio_segments.unsqueeze(0)  # (1, 48, 768); these may be embedding segments rather than raw audio

        # Convert the data to half precision and add a batch dimension to the mask
        if padding_mask.dim() == 1:
            padding_mask = padding_mask.unsqueeze(0)  # [48] -> [1, 48]
        audio_segments = audio_segments.to(device).half()
        mask = padding_mask.to(device)

        print(f"Input shape: {audio_segments.shape}")
        print(f"Mask shape: {mask.shape}")
        print(f"Mask: {mask}")

        # Run inference (with the padding mask)
        outputs = model(audio_segments, mask)
        print(f"Output type: {type(outputs)}")
        print(f"Output: {outputs}")

        # Handle the model's output structure
        if isinstance(outputs, dict):
            result = outputs
        else:
            # Single tensor (logit)
            logits = outputs.squeeze()
            prob = torch.sigmoid(logits).item()
            result = {
                "prediction": "Fake" if prob > 0.5 else "Real",
                "confidence": f"{max(prob, 1 - prob) * 100:.2f}%",
                "fake_probability": f"{prob:.4f}",
                "real_probability": f"{1 - prob:.4f}",
                "raw_output": logits.cpu().numpy().tolist()
            }

    return result
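
# Illustrative only: feeding backbone embeddings to run_inference, mirroring the
# flow in inference_with_audio below. The backbone and classifier are assumed to
# come from get_model and MusicAudioClassifier.load_from_checkpoint respectively,
# and the embedding is assumed to have shape (48, embed_dim). Never called here.
def _example_run_inference(backbone_model, classifier, audio_path, device='cuda'):
    segments, mask = load_audio(audio_path, sr=24000)
    segments = segments.to(device).to(torch.float32)
    _, embedding = backbone_model(segments.squeeze(1))          # per-segment embeddings
    return run_inference(classifier, embedding, mask.unsqueeze(0), device=device)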
def get_model(model_type, device):
    """Load the specified model."""
    if model_type == "MERT":
        from ISMIR_2025.MERT.networks import CCV
        model = CCV(embed_dim=768, num_heads=8, num_layers=6, num_classes=2, freeze_feature_extractor=True).to(device)
        ckpt_file = 'mert_finetune_10.pth'
        model.load_state_dict(torch.load(ckpt_file, map_location=device))
        embed_dim = 768
    else:
        raise ValueError(f"Unknown model type: {model_type}")
    model.eval()
    return model, embed_dim
"""
elif model_type == "music2vec":
from ISMIR_2025.music2vec.networks import Music2VecClassifier
model = Music2VecClassifier(freeze_feature_extractor=True).to(device)
ckpt_file = '/data/kym/AI_Music_Detection/Code/model/music2vec/ckpt/fakemusicretrain/musiv2vec_processor/finetune_10.pth'
embed_dim = 768
elif model_type == "wav2vec":
from ISMIR_2025.wav2vec.networks import Wav2Vec2ForFakeMusic
model = Wav2Vec2ForFakeMusic(num_classes=2, freeze_feature_extractor=True).to(device)
ckpt_file = '/data/kym/AI_Music_Detection/Code/model/wav2vec/ckpt/split/wav2vec_processor/wav2vec2_finetune_10.pth'
embed_dim = 768
elif model_type == "ccv":
from ISMIR_2025.Model.networks import CCV
model = CCV(embed_dim=512, num_heads=8, num_layers=6, num_classes=2, freeze_feature_extractor=True).to(device)
ckpt_file = '/data/kym/AI_Music_Detection/Code/model/ckpt/datasplit/hp1000/best_model_CCV.pth'
embed_dim = 512
"""
def inference_with_audio(audio_path):
    # audio_path = "The Chainsmokers & Coldplay - Something Just Like This (Lyric).mp3"
    model_type = "MERT"
    checkpoint_path = "with_embedding_MERT_768_embedding/EmbeddingModel_MERT_768-epoch=0353-val_loss=0.3866-val_acc=0.9809-val_f1=0.9803-val_precision=0.9764-val_recall=0.9842.ckpt"
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Note: Model loading would be handled by your code
    print(f"Loading model of type {model_type} from {checkpoint_path}")
    backbone_model, input_dim = get_model(model_type, device)

    # Segment the audio and extract backbone embeddings
    segments, padding_mask = load_audio(audio_path, sr=24000)
    segments = segments.to(device).to(torch.float32)
    padding_mask = padding_mask.to(device).unsqueeze(0)
    logits, embedding = backbone_model(segments.squeeze(1))

    # Run the backbone on a single fixed-duration clip from the dataset wrapper
    # (these outputs are not used below)
    test_dataset = FakeMusicCapsDataset([audio_path], [0], target_duration=10.0)
    test_data, test_target = test_dataset[0]
    test_data = test_data.to(device).to(torch.float32)
    test_target = test_target.to(device)
    output, _ = backbone_model(test_data.unsqueeze(0))

    # Load the sequence-level classifier from the Lightning checkpoint
    model = MusicAudioClassifier.load_from_checkpoint(
        checkpoint_path,
        input_dim=input_dim,
        # emb_model=backbone_model
        is_emb=True,
        mode='both'
    )

    # Run inference
    print(f"Segments shape: {segments.shape}")
    print("Running inference...")
    results = run_inference(model, embedding, padding_mask, device=device)

    # Print the results
    print(f"Results: {results}")
    return str(results)
if __name__ == "__main__":
inference_with_audio()