import os
import torch
import numpy as np
import gradio as gr
import zipfile
import json
import requests
import subprocess
import shutil
from transformers import BlipProcessor, BlipForConditionalGeneration
title = "# 🗜️ CLaMP 3 - Multimodal & Multilingual Semantic Music Search"

badges = """
<div style="text-align: center;">
    <a href="https://sanderwood.github.io/clamp3/">
        <img src="https://img.shields.io/badge/CLaMP%203%20Homepage-GitHub-181717?style=for-the-badge&logo=home-assistant" alt="Homepage">
    </a>
    <a href="https://arxiv.org/abs/2502.10362">
        <img src="https://img.shields.io/badge/CLaMP%203%20Paper-Arxiv-red?style=for-the-badge&logo=arxiv" alt="Paper">
    </a>
    <a href="https://github.com/sanderwood/clamp3">
        <img src="https://img.shields.io/badge/CLaMP%203%20Code-GitHub-181717?style=for-the-badge&logo=github" alt="GitHub">
    </a>
    <a href="https://huggingface.co/spaces/sander-wood/clamp3">
        <img src="https://img.shields.io/badge/CLaMP%203%20Demo-Gradio-green?style=for-the-badge&logo=gradio" alt="Demo">
    </a>
    <a href="https://huggingface.co/sander-wood/clamp3/tree/main">
        <img src="https://img.shields.io/badge/Model%20Weights-Hugging%20Face-ffcc00?style=for-the-badge&logo=huggingface" alt="Model Weights">
    </a>
    <a href="https://huggingface.co/datasets/sander-wood/m4-rag">
        <img src="https://img.shields.io/badge/M4--RAG%20Dataset-Hugging%20Face-ffcc00?style=for-the-badge&logo=huggingface" alt="Dataset">
    </a>
    <a href="https://huggingface.co/datasets/sander-wood/wikimt-x">
        <img src="https://img.shields.io/badge/WikiMT--X%20Benchmark-Hugging%20Face-ffcc00?style=for-the-badge&logo=huggingface" alt="Benchmark">
    </a>
</div>
<style>
div a {
    display: inline-block;
    margin: 5px;
}
div a img {
    height: 30px;
}
</style>
"""
description = """CLaMP 3 is a **multimodal and multilingual** music information retrieval (MIR) framework that supports **sheet music, audio, and performance signals** in **100 languages**. Using **contrastive learning**, it aligns these modalities in a shared representation space for **cross-modal retrieval**.

### 🔍 **How This Demo Works**
- You can **retrieve music using any text input (in any language) or an image** (`.png`, `.jpg`).
- When you upload an image, **BLIP** first generates a caption, which is then used as the retrieval query.
- Because CLaMP 3's training data includes **rich visual descriptions of musical scenes**, it can **match images to semantically relevant music**.

### ⚠️ **Limitations**
- This demo retrieves music **only from the WikiMT-X benchmark (1,000 pieces)**.
- These pieces come **mainly from the U.S. and Western Europe (especially the U.S.)** and **mostly from the 20th century**.
- Retrieval results are therefore **largely limited to Western 20th-century music**; you **won't** find music from **other regions or historical periods**.

🔧 **Need retrieval for a different music collection?** Deploy **[CLaMP 3](https://github.com/sanderwood/clamp3)** on your own dataset.
In general, the larger and more diverse the reference collection, the better the retrieval quality and the higher the chance of finding relevant, accurately matched music.

**Note: This project is for research use only.**
"""
# Load BLIP image captioning model and processor
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
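# Inference runs on CPU by default. On a GPU Space one could move the model over
# (a minimal sketch, assuming CUDA is available; the tensors built in
# generate_caption would need the same .to("cuda") treatment):
#   blip_model.to("cuda")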
# Download the weights file if it is not present locally
weights_url = "https://huggingface.co/sander-wood/clamp3/resolve/main/weights_clamp3_saas_h_size_768_t_model_FacebookAI_xlm-roberta-base_t_length_128_a_size_768_a_layers_12_a_length_128_s_size_768_s_layers_12_p_size_64_p_length_512.pth"
weights_filename = os.path.basename(weights_url)
if not os.path.exists(weights_filename):
    print("Downloading weights file...")
    response = requests.get(weights_url, stream=True)
    response.raise_for_status()
    with open(weights_filename, "wb") as f:
        for chunk in response.iter_content(chunk_size=8192):
            if chunk:
                f.write(chunk)
    print("Weights file downloaded.")
ZIP_PATH = "features.zip"
if os.path.exists(ZIP_PATH):
    print(f"Extracting {ZIP_PATH}...")
    with zipfile.ZipFile(ZIP_PATH, "r") as zip_ref:
        zip_ref.extractall(".")
    print("Extraction complete.")
# Load metadata
metadata_map = {}
METADATA_FILE = "wikimt-x-public.jsonl"
if os.path.exists(METADATA_FILE):
    with open(METADATA_FILE, "r", encoding="utf-8") as f:
        for line in f:
            data = json.loads(line)
            metadata_map[data["id"]] = data
else:
    print(f"Warning: {METADATA_FILE} not found.")
features_cache = {}

def get_info(folder_path):
    """
    Load all .npy files from the specified folder and return a dictionary
    keyed by file name (without extension).
    """
    if folder_path in features_cache:
        return features_cache[folder_path]
    if not os.path.exists(folder_path):
        return {}
    files = sorted(os.listdir(folder_path))
    features = {}
    for file in files:
        if file.endswith(".npy"):
            key = os.path.splitext(file)[0]
            try:
                features[key] = np.load(os.path.join(folder_path, file))[0]
            except Exception as e:
                print(f"Error loading {file}: {e}")
    features_cache[folder_path] = features
    return features
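# Example (assuming "text/" holds one global feature per piece, e.g.
# text/V4EauuhVEw4.npy of shape (1, 768) given the h_size_768 checkpoint):
#   feats = get_info("text/")  # {"V4EauuhVEw4": array of shape (768,), ...}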
def find_top_similar(query_file, reference_folder):
    """
    Compare the query feature with all reference features in the specified folder
    using cosine similarity and return the top 10 candidates in the format:
    Title | Artists | sim: SimilarityScore.
    """
    top_k = 10
    try:
        query_feature = np.load(query_file.name)[0]
    except Exception as e:
        return [], f"Error loading query feature: {e}"
    query_tensor = torch.tensor(query_feature, dtype=torch.float32).unsqueeze(dim=0)
    key_features = get_info(reference_folder)
    if not key_features:
        return [], f"No reference features found in {reference_folder}."
    ref_keys = list(key_features.keys())
    ref_array = np.array([key_features[k] for k in ref_keys])
    key_feats_tensor = torch.tensor(ref_array, dtype=torch.float32)
    # The (1, D) query broadcasts against the (N, D) references, so no explicit
    # expansion is needed
    similarities = torch.cosine_similarity(query_tensor, key_feats_tensor, dim=1)
    ranked_indices = torch.argsort(similarities, descending=True)
    candidate_ids = []
    candidate_display = []
    # Pad to exactly top_k entries so the UI always shows ten slots
    for i in range(top_k):
        if i < len(ref_keys):
            candidate_idx = ranked_indices[i].item()
            candidate_id = ref_keys[candidate_idx]
            sim = round(similarities[candidate_idx].item(), 4)
            meta = metadata_map.get(candidate_id, {})
            title = meta.get("title", candidate_id)
            artists = meta.get("artists", "Unknown")
            if isinstance(artists, list):
                artists = ", ".join(artists)
            candidate_ids.append(candidate_id)
            candidate_display.append(f"{title} | {artists} | sim: {sim}")
        else:
            candidate_ids.append("N/A")
            candidate_display.append("N/A")
    return candidate_ids, candidate_display
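# The ranking above is plain cosine similarity, sim(q, k) = (q . k) / (|q| |k|).
# An equivalent NumPy-only sketch (illustrative, not used by the app):
#   sims = ref_array @ query_feature / (
#       np.linalg.norm(ref_array, axis=1) * np.linalg.norm(query_feature))
#   order = np.argsort(-sims)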
def show_details(selected_id):
    """
    Return detailed metadata and embedded YouTube video HTML for the given candidate ID.
    """
    if selected_id == "N/A":
        return ("", "", "", "", "", "", "", "")
    data = metadata_map.get(selected_id, {})
    if not data:
        return ("No details found", "", "", "", "", "", "", "")
    title = data.get("title", "")
    artists = data.get("artists", "")
    if isinstance(artists, list):
        artists = ", ".join(artists)
    genre = data.get("genre", "")
    background = data.get("background", "")
    analysis = data.get("analysis", "")
    description = data.get("description", "")
    scene = data.get("scene", "")
    # The candidate ID doubles as the YouTube video ID
    youtube_html = (
        f'<iframe width="560" height="315" src="https://www.youtube.com/embed/{selected_id}" '
        f'frameborder="0" allow="accelerometer; autoplay; clipboard-write; encrypted-media; '
        f'gyroscope; picture-in-picture" allowfullscreen></iframe>'
    )
    return title, artists, genre, background, analysis, description, scene, youtube_html
def extract_features_from_text(text):
    """
    Save the input text to a file, call the CLaMP 3 feature extraction script,
    and return the path of the generated feature file.
    """
    input_dir = "input_dir"
    output_dir = "output_dir"
    os.makedirs(input_dir, exist_ok=True)
    os.makedirs(output_dir, exist_ok=True)
    # Clear input_dir and output_dir
    for d in [input_dir, output_dir]:
        for filename in os.listdir(d):
            file_path = os.path.join(d, filename)
            if os.path.isfile(file_path) or os.path.islink(file_path):
                os.unlink(file_path)
            elif os.path.isdir(file_path):
                shutil.rmtree(file_path)
    input_file = os.path.join(input_dir, "input.txt")
    print("Text input:", text)
    with open(input_file, "w", encoding="utf-8") as f:
        f.write(text)
    command = ["python", "extract_clamp3.py", input_dir, output_dir, "--get_global"]
    subprocess.run(command, check=True)
    output_file = os.path.join(output_dir, "input.npy")
    return output_file
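# Example: extract_features_from_text("upbeat 1960s British rock") writes
# input_dir/input.txt, runs extract_clamp3.py over the folder, and returns
# "output_dir/input.npy" (the code assumes the script mirrors input file names,
# so X.txt yields X.npy, with --get_global producing one global feature).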
def generate_caption(image):
    """
    Use the BLIP model to generate a descriptive caption for the given image.
    """
    inputs = processor(image, return_tensors="pt")
    outputs = blip_model.generate(**inputs)
    caption = processor.decode(outputs[0], skip_special_tokens=True)
    return caption
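# Example usage (assuming Pillow is installed and cover.jpg exists locally):
#   from PIL import Image
#   generate_caption(Image.open("cover.jpg"))  # e.g. "a band performing on stage"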
class FileWrapper:
    """
    Simulate a file object with a .name attribute.
    """
    def __init__(self, path):
        self.name = path
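# find_top_similar only reads query_file.name, so wrapping a plain path suffices:
#   ids, display = find_top_similar(FileWrapper("output_dir/input.npy"), "text/")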
def search_wrapper(search_mode, text_input, image_input):
    """
    Perform retrieval based on the selected input mode:
    - If search_mode is "Image", generate a caption for the uploaded image, extract
      features from it, and search in the "image/" folder.
    - If search_mode is "Text", extract features from the provided text and search
      in the "text/" folder.
    """
    if search_mode == "Image":
        if image_input is None:
            return text_input, gr.update(choices=[]), "Please upload an image.", "", "", "", "", "", "", ""
        caption = generate_caption(image_input)
        text_to_use = caption
        reference_folder = "image/"
    elif search_mode == "Text":
        if not text_input or text_input.strip() == "":
            return "Describe the music you're looking for (in any language)", gr.update(choices=[]), "Please enter text for retrieval.", "", "", "", "", "", "", ""
        text_to_use = text_input
        reference_folder = "text/"
    else:
        return "Describe the music you're looking for (in any language)", gr.update(choices=[]), "Invalid search mode selected.", "", "", "", "", "", "", ""
    try:
        output_file = extract_features_from_text(text_to_use)
        query_file = FileWrapper(output_file)
    except Exception as e:
        return text_to_use, gr.update(choices=[]), f"Error during feature extraction: {e}", "", "", "", "", "", "", ""
    candidate_ids, candidate_display = find_top_similar(query_file, reference_folder)
    if not candidate_ids:
        # On failure, candidate_display carries the error message from find_top_similar
        message = candidate_display if isinstance(candidate_display, str) else "No results found."
        return text_to_use, gr.update(choices=[]), message, "", "", "", "", "", "", ""
    choices = [(f"{i+1}. {disp}", cid) for i, (cid, disp) in enumerate(zip(candidate_ids, candidate_display))]
    top_candidate = candidate_ids[0]
    details = show_details(top_candidate)
    return text_to_use, gr.update(choices=choices), *details
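# Every branch above returns exactly 10 values, matching the outputs wired to
# search_wrapper below: text_input, candidate_radio, the seven metadata boxes,
# and the YouTube HTML component.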
# Define the example data (placing it after the component definitions would also work)
examples = [
    ["Image", None, "V4EauuhVEw4.jpg"],
    ["Image", None, "Kw-_Ew5bVxs.jpg"],
    ["Image", None, "BuYf0taXoNw.webp"],
    ["Image", None, "4tDYMayp6Dk.jpg"],
    ["Text", "classic rock, British, 1960s, upbeat", None],
    ["Text", "A Latin jazz piece with rhythmic percussion and brass", None],
    ["Text", "big band, major key, swing, brass-heavy, syncopation, baritone vocal", None],
    ["Text", "Heartfelt and nostalgic, with a bittersweet, melancholic feel", None],
    ["Text", "Melodía instrumental en re mayor con progresión armónica repetitiva y fluida", None],
    ["Text", "D大调四四拍的爱尔兰舞曲", None],
    ["Text", "Ιερή μουσική με πνευματική ατμόσφαιρα", None],
    ["Text", "የፍቅር ሙዚቃ ሞቅ እና ስሜታማ ከሆነ ነገር ግን ድንቅ እና አስደሳች ቃላት ያካትታል", None],
]
with gr.Blocks() as demo:
    gr.Markdown(title)
    gr.HTML(badges)
    gr.Markdown(description)
    with gr.Row():
        with gr.Column():
            search_mode = gr.Radio(
                choices=["Text", "Image"],
                label="Select Search Mode",
                value="Text",
                interactive=True,
                elem_classes=["vertical-radio"]
            )
            text_input = gr.Textbox(
                placeholder="Describe the music you're looking for (in any language)",
                lines=4
            )
            image_input = gr.Image(
                label="Or upload an image (PNG, JPG)",
                type="pil"
            )
            search_button = gr.Button("Search")
            candidate_radio = gr.Radio(choices=[], label="Select Retrieval Result", interactive=True, elem_classes=["vertical-radio"])
        with gr.Column():
            gr.Markdown("### YouTube Video")
            youtube_box = gr.HTML(label="YouTube Video")
            gr.Markdown("### Metadata")
            title_box = gr.Textbox(label="Title", interactive=False)
            artists_box = gr.Textbox(label="Artists", interactive=False)
            genre_box = gr.Textbox(label="Genre", interactive=False)
            background_box = gr.Textbox(label="Background", interactive=False)
            analysis_box = gr.Textbox(label="Analysis", interactive=False)
            description_box = gr.Textbox(label="Description", interactive=False)
            scene_box = gr.Textbox(label="Scene", interactive=False)
    gr.HTML(
        """
        <style>
        .vertical-radio .gradio-radio label {
            display: block !important;
            margin-bottom: 5px;
        }
        </style>
        """
    )
    gr.Examples(
        examples=examples,
        inputs=[search_mode, text_input, image_input],
        outputs=[text_input, candidate_radio, title_box, artists_box, genre_box, background_box, analysis_box, description_box, scene_box, youtube_box],
        fn=search_wrapper,
        cache_examples=False,
    )
    search_button.click(
        fn=search_wrapper,
        inputs=[search_mode, text_input, image_input],
        outputs=[text_input, candidate_radio, title_box, artists_box, genre_box, background_box, analysis_box, description_box, scene_box, youtube_box]
    )
    candidate_radio.change(
        fn=show_details,
        inputs=candidate_radio,
        outputs=[title_box, artists_box, genre_box, background_box, analysis_box, description_box, scene_box, youtube_box]
    )

demo.launch()