import os
import torch
import numpy as np
import gradio as gr
import zipfile
import json
import requests
import subprocess
import shutil
from transformers import BlipProcessor, BlipForConditionalGeneration
title = "# 🗜️ CLaMP 3 - Multimodal & Multilingual Semantic Music Search"
badges = """
<div style="text-align: center;">
<a href="https://sanderwood.github.io/clamp3/">
<img src="https://img.shields.io/badge/CLaMP%203%20Homepage-GitHub-181717?style=for-the-badge&logo=home-assistant" alt="Homepage">
</a>
<a href="https://arxiv.org/abs/2502.10362">
<img src="https://img.shields.io/badge/CLaMP%203%20Paper-Arxiv-red?style=for-the-badge&logo=arxiv" alt="Paper">
</a>
<a href="https://github.com/sanderwood/clamp3">
<img src="https://img.shields.io/badge/CLaMP%203%20Code-GitHub-181717?style=for-the-badge&logo=github" alt="GitHub">
</a>
<a href="https://huggingface.co/spaces/sander-wood/clamp3">
<img src="https://img.shields.io/badge/CLaMP%203%20Demo-Gradio-green?style=for-the-badge&logo=gradio" alt="Demo">
</a>
<a href="https://huggingface.co/sander-wood/clamp3/tree/main">
<img src="https://img.shields.io/badge/Model%20Weights-Hugging%20Face-ffcc00?style=for-the-badge&logo=huggingface" alt="Model Weights">
</a>
<a href="https://huggingface.co/datasets/sander-wood/m4-rag">
<img src="https://img.shields.io/badge/M4--RAG%20Dataset-Hugging%20Face-ffcc00?style=for-the-badge&logo=huggingface" alt="Dataset">
</a>
<a href="https://huggingface.co/datasets/sander-wood/wikimt-x">
<img src="https://img.shields.io/badge/WikiMT--X%20Benchmark-Hugging%20Face-ffcc00?style=for-the-badge&logo=huggingface" alt="Benchmark">
</a>
</div>
<style>
div a {
display: inline-block;
margin: 5px;
}
div a img {
height: 30px;
}
</style>
"""
description = """CLaMP 3 is a **multimodal and multilingual** music information retrieval (MIR) framework, supporting **sheet music, audio, and performance signals** in **100 languages**. Using **contrastive learning**, it aligns these modalities in a shared space for **cross-modal retrieval**.
### 🔍 **How This Demo Works**
- You can **retrieve music using any text input (in any language) or an image** (`.png`, `.jpg`).
- When using an image, **BLIP** generates a caption, which is then used for retrieval.
- Since CLaMP 3's training data includes **rich visual descriptions of musical scenes**, it can **match images to semantically relevant music**.
### ⚠️ **Limitations**
- This demo retrieves music **only from the WikiMT-X benchmark (1,000 pieces)**.
- These pieces are **mainly from the U.S. and Western Europe (especially the U.S.)** and **mostly from the 20th century**.
- Thus, retrieval results are **mostly limited to Western 20th-century music**, so you **won’t** find music from **other regions or historical periods**.
🔧 **Need retrieval for a different music collection?** Deploy **[CLaMP 3](https://github.com/sanderwood/clamp3)** on your own dataset.
In general, the larger and more diverse the reference music collection, the better the retrieval quality and the more likely you are to find relevant, accurately matched music.
**Note: This project is for research use only.**
"""
# Load BLIP image captioning model and processor
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
# Download weight file if it does not exist
weights_url = "https://huggingface.co/sander-wood/clamp3/resolve/main/weights_clamp3_saas_h_size_768_t_model_FacebookAI_xlm-roberta-base_t_length_128_a_size_768_a_layers_12_a_length_128_s_size_768_s_layers_12_p_size_64_p_length_512.pth"
weights_filename = "weights_clamp3_saas_h_size_768_t_model_FacebookAI_xlm-roberta-base_t_length_128_a_size_768_a_layers_12_a_length_128_s_size_768_s_layers_12_p_size_64_p_length_512.pth"
if not os.path.exists(weights_filename):
print("Downloading weights file...")
response = requests.get(weights_url, stream=True)
response.raise_for_status()
with open(weights_filename, "wb") as f:
for chunk in response.iter_content(chunk_size=8192):
if chunk:
f.write(chunk)
print("Weights file downloaded.")
ZIP_PATH = "features.zip"
if os.path.exists(ZIP_PATH):
print(f"Extracting {ZIP_PATH}...")
with zipfile.ZipFile(ZIP_PATH, "r") as zip_ref:
zip_ref.extractall(".")
print("Extraction complete.")
# Load metadata
metadata_map = {}
METADATA_FILE = "wikimt-x-public.jsonl"
if os.path.exists(METADATA_FILE):
with open(METADATA_FILE, "r", encoding="utf-8") as f:
for line in f:
data = json.loads(line)
metadata_map[data["id"]] = data
else:
print(f"Warning: {METADATA_FILE} not found.")
features_cache = {}
def get_info(folder_path):
"""
Load all .npy files from the specified folder and return a dictionary
with the file names (without extension) as keys.
"""
if folder_path in features_cache:
return features_cache[folder_path]
if not os.path.exists(folder_path):
return {}
files = sorted(os.listdir(folder_path))
features = {}
for file in files:
if file.endswith(".npy"):
key = file.split(".")[0]
try:
features[key] = np.load(os.path.join(folder_path, file))[0]
except Exception as e:
print(f"Error loading {file}: {e}")
features_cache[folder_path] = features
return features
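# Each feature file is assumed to hold a single global embedding of shape (1, 768)
# (hidden size per the checkpoint name), so the [0] index above selects the vector.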
def find_top_similar(query_file, reference_folder):
"""
Compare the query feature with all reference features in the specified folder
using cosine similarity and return the top 10 candidate results in the format:
Title | Artists | sim: SimilarityScore.
"""
top_k = 10
try:
query_feature = np.load(query_file.name)[0]
except Exception as e:
return [], f"Error loading query feature: {e}"
query_tensor = torch.tensor(query_feature, dtype=torch.float32).unsqueeze(dim=0)
key_features = get_info(reference_folder)
if not key_features:
return [], f"No reference features found in {reference_folder}."
ref_keys = list(key_features.keys())
ref_array = np.array([key_features[k] for k in ref_keys])
key_feats_tensor = torch.tensor(ref_array, dtype=torch.float32)
query_tensor_expanded = query_tensor.expand(key_feats_tensor.size(0), -1)
similarities = torch.cosine_similarity(query_tensor_expanded, key_feats_tensor, dim=1)
ranked_indices = torch.argsort(similarities, descending=True)
candidate_ids = []
candidate_display = []
for i in range(top_k):
if i < len(ref_keys):
candidate_idx = ranked_indices[i].item()
candidate_id = ref_keys[candidate_idx]
sim = round(similarities[candidate_idx].item(), 4)
meta = metadata_map.get(candidate_id, {})
title = meta.get("title", candidate_id)
artists = meta.get("artists", "Unknown")
if isinstance(artists, list):
artists = ", ".join(artists)
candidate_ids.append(candidate_id)
candidate_display.append(f"{title} | {artists} | sim: {sim}")
else:
candidate_ids.append("N/A")
candidate_display.append("N/A")
return candidate_ids, candidate_display
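# Illustrative candidate_display entry (values are made up):
#   "Some Title | Some Artist | sim: 0.8123"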
def show_details(selected_id):
"""
Return detailed metadata and embedded YouTube video HTML based on the candidate ID.
"""
if selected_id == "N/A":
return ("", "", "", "", "", "", "", "")
data = metadata_map.get(selected_id, {})
if not data:
return ("No details found", "", "", "", "", "", "", "")
title = data.get("title", "")
artists = data.get("artists", "")
if isinstance(artists, list):
artists = ", ".join(artists)
genre = data.get("genre", "")
background = data.get("background", "")
analysis = data.get("analysis", "")
description = data.get("description", "")
scene = data.get("scene", "")
youtube_html = (
f'<iframe width="560" height="315" src="https://www.youtube.com/embed/{selected_id}" '
f'frameborder="0" allow="accelerometer; autoplay; clipboard-write; encrypted-media; '
f'gyroscope; picture-in-picture" allowfullscreen></iframe>'
)
return title, artists, genre, background, analysis, description, scene, youtube_html
def extract_features_from_text(text):
"""
Save the input text to a file, call the CLaMP 3 feature extraction script,
and return the generated feature file path.
"""
input_dir = "input_dir"
output_dir = "output_dir"
os.makedirs(input_dir, exist_ok=True)
os.makedirs(output_dir, exist_ok=True)
# Clear input_dir and output_dir
for d in [input_dir, output_dir]:
for filename in os.listdir(d):
file_path = os.path.join(d, filename)
if os.path.isfile(file_path) or os.path.islink(file_path):
os.unlink(file_path)
elif os.path.isdir(file_path):
shutil.rmtree(file_path)
input_file = os.path.join(input_dir, "input.txt")
print("Text input:", text)
with open(input_file, "w", encoding="utf-8") as f:
f.write(text)
command = ["python", "extract_clamp3.py", input_dir, output_dir, "--get_global"]
subprocess.run(command, check=True)
output_file = os.path.join(output_dir, "input.npy")
return output_file
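# Assumption: extract_clamp3.py writes one <input stem>.npy per input file, and
# --get_global pools the sequence into a single global vector, so "input.txt"
# yields "output_dir/input.npy".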
def generate_caption(image):
"""
Use the BLIP model to generate a descriptive caption for the given image.
"""
inputs = processor(image, return_tensors="pt")
outputs = blip_model.generate(**inputs)
caption = processor.decode(outputs[0], skip_special_tokens=True)
return caption
class FileWrapper:
"""
Simulate a file object with a .name attribute.
"""
def __init__(self, path):
self.name = path
def search_wrapper(search_mode, text_input, image_input):
"""
Perform retrieval based on the selected input mode:
- If search_mode is "Image", use the uploaded image to generate a caption, then extract features
and search in the "image/" folder.
- If search_mode is "Text", use the provided text to extract features and search in the "image/" folder.
"""
if search_mode == "Image":
if image_input is None:
return text_input, gr.update(choices=[]), "Please upload an image.", "", "", "", "", "", "", ""
caption = generate_caption(image_input)
text_to_use = caption
reference_folder = "image/"
elif search_mode == "Text":
if not text_input or text_input.strip() == "":
return "Describe the music you're looking for (in any language)", gr.update(choices=[]), "Please enter text for retrieval.", "", "", "", "", "", "", ""
text_to_use = text_input
reference_folder = "text/"
else:
return "Describe the music you're looking for (in any language)", gr.update(choices=[]), "Invalid search mode selected.", "", "", "", "", "", "", ""
try:
output_file = extract_features_from_text(text_to_use)
query_file = FileWrapper(output_file)
except Exception as e:
return text_to_use, gr.update(choices=[]), f"Error during feature extraction: {e}", "", "", "", "", "", "", ""
candidate_ids, candidate_display = find_top_similar(query_file, reference_folder)
if not candidate_ids:
return text_to_use, gr.update(choices=[]), "", "", "", "", "", "", "", ""
choices = [(f"{i+1}. {disp}", cid) for i, (cid, disp) in enumerate(zip(candidate_ids, candidate_display))]
top_candidate = candidate_ids[0]
details = show_details(top_candidate)
return text_to_use, gr.update(choices=choices), *details
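# The 10-tuple returned above maps, in order, to the Gradio outputs wired up below:
# [text_input, candidate_radio, title, artists, genre, background, analysis,
#  description, scene, youtube_box].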
# Example data (the examples also work fine when defined after the component definitions)
examples = [
["Image", None, "V4EauuhVEw4.jpg"],
["Image", None, "Kw-_Ew5bVxs.jpg"],
["Image", None, "BuYf0taXoNw.webp"],
["Image", None, "4tDYMayp6Dk.jpg"],
["Text", "classic rock, British, 1960s, upbeat", None],
["Text", "A Latin jazz piece with rhythmic percussion and brass", None],
["Text", "big band, major key, swing, brass-heavy, syncopation, baritone vocal", None],
["Text", "Heartfelt and nostalgic, with a bittersweet, melancholic feel", None],
["Text", "Melodía instrumental en re mayor con progresión armónica repetitiva y fluida", None],
["Text", "D大调四四拍的爱尔兰舞曲", None],
["Text", "Ιερή μουσική με πνευματική ατμόσφαιρα", None],
["Text", "የፍቅር ሙዚቃ ሞቅ እና ስሜታማ ከሆነ ነገር ግን ድንቅ እና አስደሳች ቃላት ያካትታል", None],
]
with gr.Blocks() as demo:
gr.Markdown(title)
gr.HTML(badges)
gr.Markdown(description)
with gr.Row():
with gr.Column():
search_mode = gr.Radio(
choices=["Text", "Image"],
label="Select Search Mode",
value="Text",
interactive=True,
elem_classes=["vertical-radio"]
)
text_input = gr.Textbox(
placeholder="Describe the music you're looking for (in any language)",
lines=4
)
image_input = gr.Image(
label="Or upload an image (PNG, JPG)",
type="pil"
)
search_button = gr.Button("Search")
candidate_radio = gr.Radio(choices=[], label="Select Retrieval Result", interactive=True, elem_classes=["vertical-radio"])
with gr.Column():
gr.Markdown("### YouTube Video")
youtube_box = gr.HTML(label="YouTube Video")
gr.Markdown("### Metadata")
title_box = gr.Textbox(label="Title", interactive=False)
artists_box = gr.Textbox(label="Artists", interactive=False)
genre_box = gr.Textbox(label="Genre", interactive=False)
background_box = gr.Textbox(label="Background", interactive=False)
analysis_box = gr.Textbox(label="Analysis", interactive=False)
description_box = gr.Textbox(label="Description", interactive=False)
scene_box = gr.Textbox(label="Scene", interactive=False)
gr.HTML(
"""
<style>
.vertical-radio .gradio-radio label {
display: block !important;
margin-bottom: 5px;
}
</style>
"""
)
gr.Examples(
examples=examples,
inputs=[search_mode, text_input, image_input],
outputs=[text_input, candidate_radio, title_box, artists_box, genre_box, background_box, analysis_box, description_box, scene_box, youtube_box],
fn=search_wrapper,
cache_examples=False,
)
search_button.click(
fn=search_wrapper,
inputs=[search_mode, text_input, image_input],
outputs=[text_input, candidate_radio, title_box, artists_box, genre_box, background_box, analysis_box, description_box, scene_box, youtube_box]
)
candidate_radio.change(
fn=show_details,
inputs=candidate_radio,
outputs=[title_box, artists_box, genre_box, background_box, analysis_box, description_box, scene_box, youtube_box]
)
demo.launch()