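# Music2Emo Gradio demo: takes an audio file or a YouTube URL and predicts
# mood tags, valence, and arousal for the track.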
import os
import shutil
import json
import torch
import torchaudio
import numpy as np
import logging
import warnings
import subprocess
import math
import random
import time
from pathlib import Path
from tqdm import tqdm
from PIL import Image
from huggingface_hub import snapshot_download
from omegaconf import DictConfig
import hydra
from hydra.utils import to_absolute_path
from transformers import Wav2Vec2FeatureExtractor, AutoModel
import mir_eval
import pretty_midi as pm
import gradio as gr
from gradio import Markdown
from music21 import converter
import torchaudio.transforms as T
import matplotlib.pyplot as plt
# Import custom utilities
from utils import logger
from utils.btc_model import BTC_model
from utils.transformer_modules import *
from utils.transformer_modules import _gen_timing_signal, _gen_bias_mask
from utils.hparams import HParams
from utils.mir_eval_modules import (
    audio_file_to_features, idx2chord, idx2voca_chord,
    get_audio_paths, get_lab_paths
)
from utils.mert import FeatureExtractorMERT
from model.linear_mt_attn_ck import FeedforwardModelMTAttnCK
# Suppress unnecessary warnings and logs
warnings.filterwarnings("ignore")
logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)
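
# Pitch-class names and key-shift tables used to transpose chord labels so
# that major keys map to C major and minor keys to A minor.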
PITCH_CLASS = ['C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B']
tonic_signatures = ["A", "A#", "B", "C", "C#", "D", "D#", "E", "F", "F#", "G", "G#"]
mode_signatures = ["major", "minor"]
pitch_num_dic = {
    'C': 0, 'C#': 1, 'D': 2, 'D#': 3, 'E': 4, 'F': 5,
    'F#': 6, 'G': 7, 'G#': 8, 'A': 9, 'A#': 10, 'B': 11
}
minor_major_dic = {
    'D-': 'C#', 'E-': 'D#', 'G-': 'F#', 'A-': 'G#', 'B-': 'A#'
}
minor_major_dic2 = {
    'Db': 'C#', 'Eb': 'D#', 'Gb': 'F#', 'Ab': 'G#', 'Bb': 'A#'
}
shift_major_dic = {
    'C': 0, 'C#': 1, 'D': 2, 'D#': 3, 'E': 4, 'F': 5,
    'F#': 6, 'G': 7, 'G#': 8, 'A': 9, 'A#': 10, 'B': 11
}
shift_minor_dic = {
    'A': 0, 'A#': 1, 'B': 2, 'C': 3, 'C#': 4, 'D': 5,
    'D#': 6, 'E': 7, 'F': 8, 'F#': 9, 'G': 10, 'G#': 11,
}
flat_to_sharp_mapping = {
    "Cb": "B",
    "Db": "C#",
    "Eb": "D#",
    "Fb": "E",
    "Gb": "F#",
    "Ab": "G#",
    "Bb": "A#"
}
segment_duration = 30
resample_rate = 24000
is_split = True
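
# normalize_chord: read a chord .lab file and transpose every chord by the
# detected key's shift, so that major keys line up with C major and minor
# keys with A minor; returns the rewritten lines.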
def normalize_chord(file_path, key, key_type='major'):
    with open(file_path, 'r') as f:
        lines = f.readlines()
    if key == "None":
        new_key = "C major"
        shift = 0
    else:
        if len(key) == 1:
            key = key[0].upper()
        else:
            key = key[0].upper() + key[1:]
        if key in minor_major_dic2:
            key = minor_major_dic2[key]
        shift = 0
        if key_type == "major":
            new_key = "C major"
            shift = shift_major_dic[key]
        else:
            new_key = "A minor"
            shift = shift_minor_dic[key]
    converted_lines = []
    for line in lines:
        if line.strip():
            parts = line.split()
            start_time = parts[0]
            end_time = parts[1]
            chord = parts[2]
            if chord == "N" or chord == "X":
                newchordnorm = chord
            elif ":" in chord:
                pitch = chord.split(":")[0]
                attr = chord.split(":")[1]
                pnum = pitch_num_dic[pitch]
                new_idx = (pnum - shift) % 12
                newchord = PITCH_CLASS[new_idx]
                newchordnorm = newchord + ":" + attr
            else:
                pitch = chord
                pnum = pitch_num_dic[pitch]
                new_idx = (pnum - shift) % 12
                newchord = PITCH_CLASS[new_idx]
                newchordnorm = newchord
            converted_lines.append(f"{start_time} {end_time} {newchordnorm}\n")
    return converted_lines


def sanitize_key_signature(key):
    return key.replace('-', 'b')
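
# Audio pre-processing helpers: resample to the MERT input rate and split the
# waveform into fixed-length segments of segment_duration seconds.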
def resample_waveform(waveform, original_sample_rate, target_sample_rate):
    if original_sample_rate != target_sample_rate:
        resampler = T.Resample(original_sample_rate, target_sample_rate)
        return resampler(waveform), target_sample_rate
    return waveform, original_sample_rate


def split_audio(waveform, sample_rate):
    segment_samples = segment_duration * sample_rate
    total_samples = waveform.size(0)
    segments = []
    for start in range(0, total_samples, segment_samples):
        end = start + segment_samples
        if end <= total_samples:
            segments.append(waveform[start:end])
    if len(segments) == 0:
        segments.append(waveform)
    return segments
def safe_remove_dir(directory):
    directory = Path(directory)
    if directory.exists():
        try:
            shutil.rmtree(directory)
        except Exception as e:
            print(f"An error occurred while removing directory {directory}: {e}")
# Added: download audio from a YouTube URL
def download_audio_from_youtube(url, output_dir="inference/input"):
    import yt_dlp
    os.makedirs(output_dir, exist_ok=True)
    ydl_opts = {
        'format': 'bestaudio/best',
        'outtmpl': os.path.join(output_dir, 'tmp.%(ext)s'),
        'postprocessors': [{
            'key': 'FFmpegExtractAudio',
            'preferredcodec': 'mp3',
            'preferredquality': '192',
        }],
        'noplaylist': True,
        'quiet': True,
    }
    with yt_dlp.YoutubeDL(ydl_opts) as ydl:
        info = ydl.extract_info(url, download=True)
        title = info.get('title', 'Unknown title')
    output_file = os.path.join(output_dir, 'tmp.mp3')
    return output_file, title
# Music2emo class (existing code)
class Music2emo:
    def __init__(self,
                 name="amaai-lab/music2emo",
                 device="cuda:0",
                 cache_dir=None,
                 local_files_only=False):
        model_weights = "saved_models/J_all.ckpt"
        self.device = device
        self.feature_extractor = FeatureExtractorMERT(model_name='m-a-p/MERT-v1-95M', device=self.device, sr=resample_rate)
        self.model_weights = model_weights
        self.music2emo_model = FeedforwardModelMTAttnCK(
            input_size=768 * 2,
            output_size_classification=56,
            output_size_regression=2
        )
        checkpoint = torch.load(self.model_weights, map_location=self.device, weights_only=False)
        state_dict = {key.replace("model.", ""): value for key, value in checkpoint["state_dict"].items()}
        model_keys = set(self.music2emo_model.state_dict().keys())
        filtered_state_dict = {key: value for key, value in state_dict.items() if key in model_keys}
        self.music2emo_model.load_state_dict(filtered_state_dict)
        self.music2emo_model.to(self.device)
        self.music2emo_model.eval()
        self.config = HParams.load("./inference/data/run_config.yaml")
        self.config.feature['large_voca'] = True
        self.config.model['num_chords'] = 170
        model_file = './inference/data/btc_model_large_voca.pt'
        self.idx_to_voca = idx2voca_chord()
        self.btc_model = BTC_model(config=self.config.model).to(self.device)
        if os.path.isfile(model_file):
            checkpoint = torch.load(model_file, map_location=self.device)
            self.mean = checkpoint['mean']
            self.std = checkpoint['std']
            self.btc_model.load_state_dict(checkpoint['model'])
        self.tonic_to_idx = {tonic: idx for idx, tonic in enumerate(tonic_signatures)}
        self.mode_to_idx = {mode: idx for idx, mode in enumerate(mode_signatures)}
        self.idx_to_tonic = {idx: tonic for tonic, idx in self.tonic_to_idx.items()}
        self.idx_to_mode = {idx: mode for mode, idx in self.mode_to_idx.items()}
        with open('inference/data/chord.json', 'r') as f:
            self.chord_to_idx = json.load(f)
        with open('inference/data/chord_inv.json', 'r') as f:
            self.idx_to_chord = {int(k): v for k, v in json.load(f).items()}
        with open('inference/data/chord_root.json') as json_file:
            self.chordRootDic = json.load(json_file)
        with open('inference/data/chord_attr.json') as json_file:
            self.chordAttrDic = json.load(json_file)
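
    # predict: end-to-end inference for one audio file:
    #   1) extract MERT embeddings per segment and average them,
    #   2) run the BTC model to produce a chord .lab file,
    #   3) estimate the key and normalize the chords to C major / A minor,
    #   4) encode the chord sequence and feed everything to the multi-task
    #      model to obtain mood probabilities, valence, and arousal.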
    def predict(self, audio, threshold=0.5):
        feature_dir = Path("./inference/temp_out")
        output_dir = Path("./inference/output")
        safe_remove_dir(feature_dir)
        safe_remove_dir(output_dir)
        feature_dir.mkdir(parents=True, exist_ok=True)
        output_dir.mkdir(parents=True, exist_ok=True)
        warnings.filterwarnings('ignore')
        logger.logging_verbosity(1)
        mert_dir = feature_dir / "mert"
        mert_dir.mkdir(parents=True, exist_ok=True)
        waveform, sample_rate = torchaudio.load(audio)
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0).unsqueeze(0)
        waveform = waveform.squeeze()
        waveform, sample_rate = resample_waveform(waveform, sample_rate, resample_rate)
        if is_split:
            segments = split_audio(waveform, sample_rate)
            for i, segment in enumerate(segments):
                segment_save_path = os.path.join(mert_dir, f"segment_{i}.npy")
                self.feature_extractor.extract_features_from_segment(segment, sample_rate, segment_save_path)
        else:
            segment_save_path = os.path.join(mert_dir, "segment_0.npy")
            self.feature_extractor.extract_features_from_segment(waveform, sample_rate, segment_save_path)
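        # Aggregate MERT features: concatenate hidden states from layers 5 and 6
        # for each segment, then average across segments into a single vector.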
        segment_embeddings = []
        layers_to_extract = [5, 6]
        for filename in sorted(os.listdir(mert_dir)):
            file_path = os.path.join(mert_dir, filename)
            if os.path.isfile(file_path) and filename.endswith('.npy'):
                segment = np.load(file_path)
                concatenated_features = np.concatenate(
                    [segment[:, layer_idx, :] for layer_idx in layers_to_extract], axis=1
                )
                concatenated_features = np.squeeze(concatenated_features)
                segment_embeddings.append(concatenated_features)
        segment_embeddings = np.array(segment_embeddings)
        if len(segment_embeddings) > 0:
            final_embedding_mert = np.mean(segment_embeddings, axis=0)
        else:
            final_embedding_mert = np.zeros((1536,))
        final_embedding_mert = torch.from_numpy(final_embedding_mert).to(self.device)
        audio_path = audio
        audio_id = os.path.split(audio_path)[-1][:-4]
        try:
            feature, feature_per_second, song_length_second = audio_file_to_features(audio_path, self.config)
        except Exception:
            logger.info("Failed to load audio file: %s" % audio_path)
            raise
        logger.info("Successfully loaded audio file and computed features: %s" % audio_path)
        feature = feature.T
        feature = (feature - self.mean) / self.std
        time_unit = feature_per_second
        n_timestep = self.config.model['timestep']
        num_pad = n_timestep - (feature.shape[0] % n_timestep)
        feature = np.pad(feature, ((0, num_pad), (0, 0)), mode="constant", constant_values=0)
        num_instance = feature.shape[0] // n_timestep
        start_time = 0.0
        lines = []
        with torch.no_grad():
            self.btc_model.eval()
            feature = torch.tensor(feature, dtype=torch.float32).unsqueeze(0).to(self.device)
            for t in range(num_instance):
                self_attn_output, _ = self.btc_model.self_attn_layers(feature[:, n_timestep * t:n_timestep * (t + 1), :])
                prediction, _ = self.btc_model.output_layer(self_attn_output)
                prediction = prediction.squeeze()
                for i in range(n_timestep):
                    if t == 0 and i == 0:
                        prev_chord = prediction[i].item()
                        continue
                    if prediction[i].item() != prev_chord:
                        lines.append('%.3f %.3f %s\n' % (start_time, time_unit * (n_timestep * t + i), self.idx_to_voca[prev_chord]))
                        start_time = time_unit * (n_timestep * t + i)
                        prev_chord = prediction[i].item()
                    if t == num_instance - 1 and i + num_pad == n_timestep:
                        if start_time != time_unit * (n_timestep * t + i):
                            lines.append('%.3f %.3f %s\n' % (start_time, time_unit * (n_timestep * t + i), self.idx_to_voca[prev_chord]))
                        break
        save_path = os.path.join(feature_dir, os.path.split(audio_path)[-1].replace('.mp3', '').replace('.wav', '') + '.lab')
        with open(save_path, 'w') as f:
            for line in lines:
                f.write(line)
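        # Estimate the key with music21 (falling back to "None", i.e. C major,
        # if no companion MIDI file can be parsed) and write a key-normalized
        # copy of the .lab file.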
        try:
            midi_file = converter.parse(save_path.replace('.lab', '.midi'))
            key_signature = str(midi_file.analyze('key'))
        except Exception as e:
            key_signature = "None"
        key_parts = key_signature.split()
        key_signature = sanitize_key_signature(key_parts[0])
        key_type = key_parts[1] if len(key_parts) > 1 else 'major'
        converted_lines = normalize_chord(save_path, key_signature, key_type)
        lab_norm_path = save_path[:-4] + "_norm.lab"
        with open(lab_norm_path, 'w') as f:
            f.writelines(converted_lines)
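        # Encode the chord sequence as root / attribute / full-chord indices,
        # truncated or zero-padded to a fixed length of 100.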
        chords = []
        if not os.path.exists(lab_norm_path):
            chords.append((float(0), float(0), "N"))
        else:
            with open(lab_norm_path, 'r') as file:
                for line in file:
                    start, end, chord = line.strip().split()
                    chords.append((float(start), float(end), chord))
        encoded = []
        encoded_root = []
        encoded_attr = []
        durations = []
        for start, end, chord in chords:
            chord_arr = chord.split(":")
            if len(chord_arr) == 1:
                chordRootID = self.chordRootDic[chord_arr[0]]
                chordAttrID = 0 if chord_arr[0] in ["N", "X"] else 1
            elif len(chord_arr) == 2:
                chordRootID = self.chordRootDic[chord_arr[0]]
                chordAttrID = self.chordAttrDic[chord_arr[1]]
            encoded_root.append(chordRootID)
            encoded_attr.append(chordAttrID)
            if chord in self.chord_to_idx:
                encoded.append(self.chord_to_idx[chord])
            else:
                print(f"Warning: {chord} was not found in chord.json. Skipping.")
            durations.append(end - start)
        encoded_chords = np.array(encoded)
        encoded_chords_root = np.array(encoded_root)
        encoded_chords_attr = np.array(encoded_attr)
        max_sequence_length = 100
        if len(encoded_chords) > max_sequence_length:
            encoded_chords = encoded_chords[:max_sequence_length]
            encoded_chords_root = encoded_chords_root[:max_sequence_length]
            encoded_chords_attr = encoded_chords_attr[:max_sequence_length]
        else:
            padding = [0] * (max_sequence_length - len(encoded_chords))
            encoded_chords = np.concatenate([encoded_chords, padding])
            encoded_chords_root = np.concatenate([encoded_chords_root, padding])
            encoded_chords_attr = np.concatenate([encoded_chords_attr, padding])
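        # Run the multi-task model: the sigmoid of the classification head gives
        # per-mood probabilities, the regression head gives valence and arousal.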
        chords_tensor = torch.tensor(encoded_chords, dtype=torch.long).to(self.device)
        chords_root_tensor = torch.tensor(encoded_chords_root, dtype=torch.long).to(self.device)
        chords_attr_tensor = torch.tensor(encoded_chords_attr, dtype=torch.long).to(self.device)
        model_input_dic = {
            "x_mert": final_embedding_mert.unsqueeze(0),
            "x_chord": chords_tensor.unsqueeze(0),
            "x_chord_root": chords_root_tensor.unsqueeze(0),
            "x_chord_attr": chords_attr_tensor.unsqueeze(0),
            "x_key": torch.tensor([self.mode_to_idx.get(key_type, 0)], dtype=torch.long).unsqueeze(0).to(self.device)
        }
        model_input_dic = {k: v.to(self.device) for k, v in model_input_dic.items()}
        classification_output, regression_output = self.music2emo_model(model_input_dic)
        tag_list = np.load("./inference/data/tag_list.npy")
        tag_list = tag_list[127:]
        mood_list = [t.replace("mood/theme---", "") for t in tag_list]
        probs = torch.sigmoid(classification_output).squeeze().tolist()
        predicted_moods_with_scores = [
            {"mood": mood_list[i], "score": round(p, 4)}
            for i, p in enumerate(probs) if p > threshold
        ]
        predicted_moods_with_scores_all = [
            {"mood": mood_list[i], "score": round(p, 4)}
            for i, p in enumerate(probs)
        ]
        predicted_moods_with_scores.sort(key=lambda x: x["score"], reverse=True)
        valence, arousal = regression_output.squeeze().tolist()
        model_output_dic = {
            "valence": valence,
            "arousal": arousal,
            "predicted_moods": predicted_moods_with_scores,
            "predicted_moods_all": predicted_moods_with_scores_all
        }
        return model_output_dic
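
# A minimal usage sketch outside the Gradio UI (assuming a local "song.mp3"):
#   m2e = Music2emo(device="cpu")
#   out = m2e.predict("song.mp3", threshold=0.5)
#   print(out["valence"], out["arousal"], out["predicted_moods"])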
# Initialize the Music2Emo model
if torch.cuda.is_available():
    music2emo = Music2emo()
else:
    music2emo = Music2emo(device="cpu")
# Handle the input (an audio file or a YouTube URL)
def process_input(audio, youtube_url, threshold):
    if youtube_url and youtube_url.strip().startswith("http"):
        # A YouTube URL was provided, so download its audio first
        audio_file, video_title = download_audio_from_youtube(youtube_url)
        output_dic = music2emo.predict(audio_file, threshold)
        output_text, va_chart, mood_chart = format_prediction(output_dic)
        output_text += f"\nVideo title: {video_title}"
        return output_text, va_chart, mood_chart
    elif audio:
        output_dic = music2emo.predict(audio, threshold)
        return format_prediction(output_dic)
    else:
        return "Please provide an audio file or a YouTube URL.", None, None
# Format the analysis results
def format_prediction(model_output_dic):
    valence = model_output_dic["valence"]
    arousal = model_output_dic["arousal"]
    predicted_moods_with_scores = model_output_dic["predicted_moods"]
    predicted_moods_with_scores_all = model_output_dic["predicted_moods_all"]
    va_chart = plot_valence_arousal(valence, arousal)
    mood_chart = plot_mood_probabilities(predicted_moods_with_scores_all)
    if predicted_moods_with_scores:
        moods_text = ", ".join([f"{m['mood']} ({m['score']:.2f})" for m in predicted_moods_with_scores])
    else:
        moods_text = "No salient moods were detected."
    output_text = f"""🎭 Mood tags: {moods_text}
💖 Valence: {valence:.2f} (scale of 1-9)
⚡ Arousal: {arousal:.2f} (scale of 1-9)"""
    return output_text, va_chart, mood_chart
def plot_mood_probabilities(predicted_moods_with_scores):
    if not predicted_moods_with_scores:
        return None
    moods = [m["mood"] for m in predicted_moods_with_scores]
    probs = [m["score"] for m in predicted_moods_with_scores]
    sorted_indices = np.argsort(probs)[::-1]
    sorted_probs = [probs[i] for i in sorted_indices]
    sorted_moods = [moods[i] for i in sorted_indices]
    fig, ax = plt.subplots(figsize=(8, 4))
    ax.barh(sorted_moods[:10], sorted_probs[:10], color="#4CAF50")
    ax.set_xlabel("Probability")
    ax.set_title("Top 10 mood tags")
    ax.invert_yaxis()
    return fig
def plot_valence_arousal(valence, arousal):
    fig, ax = plt.subplots(figsize=(4, 4))
    ax.scatter(valence, arousal, color="red", s=100)
    ax.set_xlim(1, 9)
    ax.set_ylim(1, 9)
    ax.axhline(y=5, color='gray', linestyle='--', linewidth=1)
    ax.axvline(x=5, color='gray', linestyle='--', linewidth=1)
    ax.set_xlabel("Valence (positivity)")
    ax.set_ylabel("Arousal (energy)")
    ax.set_title("Valence-Arousal plot")
    ax.grid(True, linestyle="--", alpha=0.6)
    return fig
# Gradio UI setup
title = "🎵 Music2Emo: Unified Music Emotion Recognition"
description_text = """
<p>
Provide an audio file or a YouTube URL and Music2Emo will analyze the emotional characteristics of the music.<br/><br/>
This demo predicts 1) mood tags, 2) valence (scale of 1-9), and 3) arousal (scale of 1-9).<br/><br/>
For details, please see the <a href="https://arxiv.org/abs/2502.03979" target="_blank">paper</a>.
</p>
"""
css = """ | |
.gradio-container { | |
font-family: 'Inter', -apple-system, system-ui, sans-serif; | |
} | |
.gr-button { | |
color: white; | |
background: #4CAF50; | |
border-radius: 8px; | |
padding: 10px; | |
} | |
.gr-box { | |
padding-top: 25px !important; | |
} | |
""" | |
with gr.Blocks(css=css) as demo:
    gr.HTML(f"<h1 style='text-align: center;'>{title}</h1>")
    gr.Markdown(description_text)
    gr.Markdown("""
### 📝 Notes:
- **Supported audio formats:** MP3, WAV
- **YouTube URL:** can also be provided (optional)
- **Recommended:** high-quality audio files
""")
    with gr.Row():
        with gr.Column(scale=1):
            input_audio = gr.Audio(label="Upload an audio file", type="filepath")
            youtube_url = gr.Textbox(label="YouTube URL (optional)", placeholder="e.g. https://youtu.be/XXXXXXX")
            threshold = gr.Slider(minimum=0.0, maximum=1.0, value=0.5, step=0.01, label="Mood detection threshold", info="Adjust the threshold as needed")
            predict_btn = gr.Button("🎭 Analyze emotion", variant="primary")
        with gr.Column(scale=1):
            output_text = gr.Textbox(label="Analysis results", lines=4, interactive=False)
            with gr.Row(equal_height=True):
                mood_chart = gr.Plot(label="Mood probabilities", scale=2, elem_classes=["gr-box"])
                va_chart = gr.Plot(label="Valence-Arousal", scale=1, elem_classes=["gr-box"])
    predict_btn.click(
        fn=process_input,
        inputs=[input_audio, youtube_url, threshold],
        outputs=[output_text, va_chart, mood_chart]
    )

demo.queue().launch()