import os from pathlib import Path import json from beat_this.inference import File2Beats import numpy as np import torch from typing import List, Tuple, Optional import pytorch_lightning as pl from model import MusicClassifier, MusicAudioClassifier import torch import torchaudio import scipy.signal as signal from typing import Dict, List from dataset_f import FakeMusicCapsDataset from preprocess import get_segments_from_wav, find_optimal_segment_length def highpass_filter(y, sr, cutoff=1000, order=5): if isinstance(sr, np.ndarray): sr = np.mean(sr) if not isinstance(sr, (int, float)): raise ValueError(f"sr must be a number, but got {type(sr)}: {sr}") nyquist = 0.5 * sr if cutoff <= 0 or cutoff >= nyquist: cutoff = max(10, min(cutoff, nyquist - 1)) normal_cutoff = cutoff / nyquist b, a = signal.butter(order, normal_cutoff, btype='high', analog=False) y_filtered = signal.lfilter(b, a, y) return y_filtered def load_audio(audio_path: str, sr: int = 24000) -> Tuple[torch.Tensor, torch.Tensor]: """ 오디오 파일을 불러와 세그먼트로 분할합니다. 고정된 길이의 세그먼트를 최대 48개 추출하고, 부족한 경우 패딩을 추가합니다. Args: audio_path: 오디오 파일 경로 sr: 목표 샘플링 레이트 (기본값 24000) Returns: Tuple containing: - 오디오 파형이 담긴 텐서 (48, 1, 240000) - 패딩 마스크 텐서 (48), True = 패딩, False = 실제 오디오 """ beats, downbeats = get_segments_from_wav(audio_path) optimal_length, cleaned_downbeats = find_optimal_segment_length(downbeats) waveform, sample_rate = torchaudio.load(audio_path) # 데이터 타입을 float32로 변환 waveform = waveform.to(torch.float32) if sample_rate != sr: resampler = torchaudio.transforms.Resample(sample_rate, sr) waveform = resampler(waveform) # 모노로 변환 (필요한 경우) if waveform.shape[0] > 1: waveform = torch.mean(waveform, dim=0, keepdim=True) # 240000 샘플 = 10초 @ 24kHz fixed_samples = 240000 # 각 downbeat에서 시작하는 segment 생성 segments = [] # 다운비트가 없거나 매우 적을 경우 전체 오디오를 일정 간격으로 분할 if len(cleaned_downbeats) < 2: # 오디오 총 길이 (초) total_duration = waveform.size(1) / sr # 5초 간격으로 세그먼트 시작점 생성 (또는 더 짧은 간격으로 설정 가능) # 240000 샘플은 10초이므로 5초 간격은 세그먼트 간 50% 오버랩 segment_interval = 5.0 # 초 단위 # 시작 시간 목록 생성 (0초부터 시작) start_times = [t for t in np.arange(0, total_duration - (fixed_samples/sr) + 0.01, segment_interval)] # 최소한 하나의 세그먼트는 보장 if not start_times and total_duration > 0: start_times = [0.0] else: # 기존 방식대로 다운비트 사용 start_times = cleaned_downbeats # 세그먼트 추출 for i, start_time in enumerate(start_times): # 시작 샘플 인덱스 계산 start_sample = int(start_time * sr) # 끝 샘플 인덱스 계산 (시작 지점 + 고정 길이) end_sample = start_sample + fixed_samples # 파일 끝을 넘어가는지 확인 if end_sample > waveform.size(1): # 짧은 곡의 경우: 끝에서부터 거꾸로 세그먼트 추출 시도 if start_sample < waveform.size(1) and waveform.size(1) >= fixed_samples: start_sample = waveform.size(1) - fixed_samples end_sample = waveform.size(1) else: continue # 정확히 fixed_samples 길이의 세그먼트 추출 segment = waveform[:, start_sample:end_sample] # 하이패스 필터 적용 - 채널 차원 유지 filtered = torch.tensor(highpass_filter(segment.squeeze().numpy(), sr)).unsqueeze(0) segments.append(filtered) # 최대 48개 세그먼트만 사용 if len(segments) >= 48: break # 세그먼트가 없는 경우, 곡이 너무 짧아서 고정 길이 세그먼트를 만들 수 없는 경우 if not segments: if waveform.size(1) > 0: # 오디오가 있지만 매우 짧은 경우 # 전체 오디오를 하나의 세그먼트로 사용하고 나머지는 제로 패딩 segment = waveform # 필요한 길이에 맞게 패딩 추가 padding_length = fixed_samples - segment.size(1) if padding_length > 0: segment = torch.nn.functional.pad(segment, (0, padding_length)) # 하이패스 필터 적용 filtered = torch.tensor(highpass_filter(segment.squeeze().numpy(), sr)).unsqueeze(0) segments.append(filtered) else: # 완전히 빈 오디오일 경우 return torch.zeros((48, 1, fixed_samples), dtype=torch.float32), torch.ones(48, dtype=torch.bool) # 스택하여 텐서로 변환 - (n_segments, 1, time_samples) 형태 유지 stacked_segments = torch.stack(segments) # 실제 세그먼트 수 (패딩 아님) num_segments = stacked_segments.shape[0] # 패딩 마스크 생성 (False = 실제 오디오, True = 패딩) padding_mask = torch.zeros(48, dtype=torch.bool) # 48개 미만인 경우 패딩 추가 if num_segments < 48: # 빈 세그먼트로 패딩 (zeros) padding = torch.zeros((48 - num_segments, 1, fixed_samples), dtype=torch.float32) stacked_segments = torch.cat([stacked_segments, padding], dim=0) # 패딩 마스크 설정 (True = 패딩) padding_mask[num_segments:] = True return stacked_segments, padding_mask def run_inference(model, audio_segments: torch.Tensor, padding_mask: torch.Tensor, device: str = 'cuda' if torch.cuda.is_available() else 'cpu') -> Dict: """ Run inference on audio segments. Args: model: The loaded model audio_segments: Preprocessed audio segments tensor (48, 1, 240000) device: Device to run inference on Returns: Dictionary with prediction results """ model.eval() model.to(device) model = model.half() print(padding_mask.shape) with torch.no_grad(): # 데이터 형태 확인 및 조정 # wav_collate_with_mask 함수와 일치하도록 처리 if audio_segments.shape[1] == 1: # (48, 1, 240000) 형태 # 채널 차원 제거하고 배치 차원 추가 audio_segments = audio_segments[:, 0, :].unsqueeze(0) # (1, 48, 240000) else: audio_segments = audio_segments.unsqueeze(0) # (1, 48, 768) # 사실 audio가 아니라 embedding segments일수도 # 데이터를 half 타입으로 변환 if padding_mask.dim() == 1: padding_mask = padding_mask.unsqueeze(0) # [48] -> [1, 48] audio_segments = audio_segments.to(device).half() mask = padding_mask.to(device) print(f"Input shape: {audio_segments.shape}") print(f"Mask shape: {mask.shape}") print(f"Mask: {mask}") # 추론 실행 (마스크 포함) outputs = model(audio_segments, mask) print(f"Output type: {type(outputs)}") print(f"Output: {outputs}") # 모델 출력 구조에 따라 처리 if isinstance(outputs, dict): result = outputs else: # 단일 텐서인 경우 (로짓) logits = outputs.squeeze() prob = torch.sigmoid(logits).item() result = { "prediction": "Fake" if prob > 0.5 else "Real", "confidence": f"{max(prob, 1-prob)*100:.2f}%", "fake_probability": f"{prob:.4f}", "real_probability": f"{1-prob:.4f}", "raw_output": logits.cpu().numpy().tolist() } return result def get_model(model_type, device): """Load the specified model.""" if model_type == "MERT": from ISMIR_2025.MERT.networks import CCV model = CCV(embed_dim=768, num_heads=8, num_layers=6, num_classes=2, freeze_feature_extractor=True).to(device) ckpt_file = 'mert_finetune_10.pth' model.load_state_dict(torch.load(ckpt_file, map_location=device)) embed_dim = 768 else: raise ValueError(f"Unknown model type: {model_type}") model.eval() return model, embed_dim """ elif model_type == "music2vec": from ISMIR_2025.music2vec.networks import Music2VecClassifier model = Music2VecClassifier(freeze_feature_extractor=True).to(device) ckpt_file = '/data/kym/AI_Music_Detection/Code/model/music2vec/ckpt/fakemusicretrain/musiv2vec_processor/finetune_10.pth' embed_dim = 768 elif model_type == "wav2vec": from ISMIR_2025.wav2vec.networks import Wav2Vec2ForFakeMusic model = Wav2Vec2ForFakeMusic(num_classes=2, freeze_feature_extractor=True).to(device) ckpt_file = '/data/kym/AI_Music_Detection/Code/model/wav2vec/ckpt/split/wav2vec_processor/wav2vec2_finetune_10.pth' embed_dim = 768 elif model_type == "ccv": from ISMIR_2025.Model.networks import CCV model = CCV(embed_dim=512, num_heads=8, num_layers=6, num_classes=2, freeze_feature_extractor=True).to(device) ckpt_file = '/data/kym/AI_Music_Detection/Code/model/ckpt/datasplit/hp1000/best_model_CCV.pth' embed_dim = 512 """ def inference_with_audio(audio_path): #audio_path = "The Chainsmokers & Coldplay - Something Just Like This (Lyric).mp3" model_type = "MERT" checkpoint_path = "with_embedding_MERT_768_embedding/EmbeddingModel_MERT_768-epoch=0353-val_loss=0.3866-val_acc=0.9809-val_f1=0.9803-val_precision=0.9764-val_recall=0.9842.ckpt" device = 'cuda' # Note: Model loading would be handled by your code print(f"Loading model of type {model_type} from {checkpoint_path}") backbone_model, input_dim = get_model(model_type, device) segments, padding_mask = load_audio(audio_path, sr=24000) segments = segments.to(device).to(torch.float32) padding_mask = padding_mask.to(device).unsqueeze(0) logits,embedding = backbone_model(segments.squeeze(1)) test_dataset = FakeMusicCapsDataset([audio_path], [0], target_duration=10.0) test_data, test_target = test_dataset[0] test_data = test_data.to(device).to(torch.float32) test_target = test_target.to(device) output, _ = backbone_model(test_data.unsqueeze(0)) # 모델 로드 부분 추가 model = MusicAudioClassifier.load_from_checkpoint( checkpoint_path, input_dim=input_dim, #emb_model=backbone_model is_emb = True, mode = 'both' ) # Run inference print(f"Segments shape: {segments.shape}") print("Running inference...") results = run_inference(model, embedding, padding_mask, device=device) # 결과 출력 print(f"Results: {results}") return str(results) if __name__ == "__main__": inference_with_audio()