from dataclasses import dataclass, field
import logging
import spaces

import sys
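# Make the bundled SonicVerse sources importable inside the Space.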
sys.path.append("/home/user/app/src/sonicverse")

from huggingface_hub import login
import os

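# Log in to the Hugging Face Hub; required to download gated or private model weights.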
hf_token = os.getenv("HUGGINGFACE_HUB_TOKEN")
if not hf_token:
    raise ValueError("Missing HUGGINGFACE_HUB_TOKEN. Set it as a secret in your Space.")

login(token=hf_token)

import gradio as gr
import torch
import transformers
import torchaudio

from openai import OpenAI

# The OpenAI client reads OPENAI_API_KEY from the environment.
client = OpenAI()
MODEL = "gpt-4"

from sonicverse.model_utils import MultiTaskType
from sonicverse.training import ModelArguments
from sonicverse.inference import load_trained_lora_model
from sonicverse.data_tools import encode_chat

CHUNK_LENGTH = 10  # seconds of audio per captioning chunk

@dataclass
class ServeArguments(ModelArguments):
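    """Serving-time overrides on top of ModelArguments (quantization and generation settings)."""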
    load_bits: int = field(default=16)
    max_new_tokens: int = field(default=128)
    temperature: float = field(default=0.01)


logging.getLogger().setLevel(logging.INFO)

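# Parse serving arguments (model paths, quantization, generation settings) from the launch command.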
parser = transformers.HfArgumentParser((ServeArguments,))
serve_args, _ = parser.parse_args_into_dataclasses(return_remaining_strings=True)

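# Load the trained SonicVerse model (base weights + LoRA adapter) once at startup.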
model, tokenizer = load_trained_lora_model(
    model_name_or_path=serve_args.model_name_or_path,
    model_lora_path=serve_args.model_lora_path,
    load_bits=serve_args.load_bits,
    use_multi_task=MultiTaskType(serve_args.use_multi_task),
    tasks_config=serve_args.tasks_config
)

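# ZeroGPU: borrow a GPU for up to 150 seconds per request.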
@spaces.GPU(duration=150)
def caption_audio(audio_file):
    """Caption a music clip: caption each 10-second chunk, then summarize long clips."""
    chunk_audio_files = split_audio(audio_file, CHUNK_LENGTH)
    chunk_captions = [generate_caption(chunk) for chunk in chunk_audio_files]

    if len(chunk_captions) > 1:
        audio_name = os.path.splitext(os.path.basename(audio_file))[0]
        long_caption = summarize_song(audio_name, chunk_captions)
        delete_files(chunk_audio_files)
        return long_caption

    if len(chunk_captions) == 1:
        return chunk_captions[0]
    return ""

def summarize_song(song_name, chunks):
    """Merge chronological chunk captions into one full-song description via the OpenAI API."""
    prompt = f"""
You are a music critic. Given the following chronological 10-second chunk descriptions of a single piece, write one flowing, detailed description of the entire song: its structure, instrumentation, and standout moments. Mention transition points with time stamps. If the description of a chunk does not fit those of the chunks before and after it, treat it as a low-accuracy description and do not incorporate its information. Retain concrete musical attributes such as key, chords, and tempo.

Chunks for "{song_name}":
"""
    for i, c in enumerate(chunks, 1):
        prompt += f"\n{(i - 1) * 10} to {i * 10} seconds: {c.strip()}"
    prompt += "\n\nFull song description:"

    resp = client.chat.completions.create(
        model=MODEL,
        messages=[
            {"role": "system", "content": "You are an expert music writer."},
            {"role": "user", "content": prompt},
        ],
        temperature=0.0,
        max_tokens=1000,
    )
    return resp.choices[0].message.content.strip()

def delete_files(file_paths):
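    """Best-effort deletion of temporary chunk files; prints failures instead of raising."""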
    for path in file_paths:
        try:
            if os.path.isfile(path):
                os.remove(path)
                print(f"Deleted: {path}")
            else:
                print(f"Skipped (not a file or doesn't exist): {path}")
        except Exception as e:
            print(f"Error deleting {path}: {e}")

def split_audio(input_path, chunk_length_seconds):
    """Split an audio file into fixed-length chunks and return the chunk paths.

    If the clip fits in a single chunk, the original path is returned unchanged.
    """
    waveform, sample_rate = torchaudio.load(input_path)
    _, total_samples = waveform.shape
    chunk_samples = int(chunk_length_seconds * sample_rate)

    # Ceiling division: the final chunk may be shorter than chunk_length_seconds.
    num_chunks = (total_samples + chunk_samples - 1) // chunk_samples
    if num_chunks <= 1:
        return [input_path]

    base, ext = os.path.splitext(input_path)
    output_paths = []

    for i in range(num_chunks):
        start = i * chunk_samples
        end = min((i + 1) * chunk_samples, total_samples)
        chunk_waveform = waveform[:, start:end]

        output_file = f"{base}_{i+1:03d}{ext}"
        torchaudio.save(output_file, chunk_waveform, sample_rate)
        output_paths.append(output_file)

    return output_paths

def generate_caption(audio_file):
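    """Run the local SonicVerse model on one audio file and return its caption."""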
    req_json = {
        "messages": [
            {"role": "user", "content": "Describe the music. <sound>"}
        ],
        "sounds": [audio_file]
    }

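    # Tokenize the chat prompt and preprocess the referenced audio for each modality.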
    encoded_dict = encode_chat(req_json, tokenizer, model.modalities)

    with torch.inference_mode():
        output_ids = model.generate(
            input_ids=encoded_dict["input_ids"].unsqueeze(0).to(model.device),
            max_new_tokens=serve_args.max_new_tokens,
            use_cache=True,
            do_sample=True,
            temperature=serve_args.temperature,
            modality_inputs={
                m.name: [encoded_dict[m.name]] for m in model.modalities
            },
        )

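    # Decode only the newly generated tokens, skipping the prompt portion.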
    outputs = tokenizer.decode(
        output_ids[0, encoded_dict["input_ids"].shape[0]:],
        skip_special_tokens=True
    ).strip()

    return outputs


with gr.Blocks(title="SonicVerse") as demo:
    gr.Markdown("""
        # 🎼 SonicVerse: Music Captioning Demo

        Welcome to **SonicVerse**, a multi-task music captioning model that provides natural language descriptions of input clips.

        🎵 Captions include music features such as:
        - Genre
        - Mood
        - Instrumentation
        - Vocals
        - Key

        📘 [Read the Paper](https://arxiv.org/abs/2506.15154)

        🖥️ [Run it locally](https://github.com/amaai-lab/SonicVerse)


        ⚠️ **Note:** You can upload audio of any length, but due to compute limits on Hugging Face Spaces,
        it is recommended to keep clips under 30 seconds unless you have a Pro account or run this locally.
    """)

    with gr.Row():
        audio_input = gr.Audio(type="filepath", label="Upload your music clip")
        caption_output = gr.Textbox(label="Generated Caption", lines=8)

    submit_btn = gr.Button("Generate Caption")
    submit_btn.click(fn=caption_audio, inputs=audio_input, outputs=caption_output)



if __name__ == "__main__":
    demo.launch()