import streamlit as st
import torch
from diffusers import AudioLDMPipeline
from transformers import AutoProcessor, ClapModel

# make Space compatible with CPU duplicates
if torch.cuda.is_available():
    device = "cuda"
    torch_dtype = torch.float16
else:
    device = "cpu"
    torch_dtype = torch.float32

# load the diffusers pipeline
repo_id = "cvssp/audioldm-m-full"
pipe = AudioLDMPipeline.from_pretrained(repo_id, torch_dtype=torch_dtype).to(device)
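# compile the UNet for faster repeated inference (requires PyTorch 2.0+)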
pipe.unet = torch.compile(pipe.unet)

# CLAP model (only required for automatic scoring)
clap_model = ClapModel.from_pretrained("sanchit-gandhi/clap-htsat-unfused-m-full").to(device)
processor = AutoProcessor.from_pretrained("sanchit-gandhi/clap-htsat-unfused-m-full")

generator = torch.Generator(device)
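# (re-seeded with the user-supplied seed on each run, so a fixed seed reproduces the output)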

# Streamlit app setup
st.set_page_config(
    page_title="Text to Music",
    page_icon="🎵",
)

text_input = st.text_input("Input text", "A hammer is hitting a wooden surface")
negative_prompt = st.text_input("Negative prompt", "low quality, average quality")

st.markdown("### Configuration")
seed = st.number_input("Seed", value=45)
duration = st.slider("Duration (seconds)", 2.5, 10.0, 5.0, 2.5)
guidance_scale = st.slider("Guidance scale", 0.0, 4.0, 2.5, 0.5)
n_candidates = st.slider("Number of waveforms to generate", 1, 3, 3, 1)
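# generating several candidates lets CLAP pick the best match (see score_waveforms below)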

def score_waveforms(text, waveforms):
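    """Score each candidate waveform against the prompt with CLAP and
    return the waveform with the highest text-audio similarity."""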
    inputs = processor(text=text, audios=list(waveforms), return_tensors="pt", padding=True)
    inputs = {key: inputs[key].to(device) for key in inputs}
    with torch.no_grad():
        logits_per_text = clap_model(**inputs).logits_per_text  # text-audio similarity scores
        probs = logits_per_text.softmax(dim=-1)  # probabilities over the candidate waveforms
        most_probable = torch.argmax(probs)  # index of the best-matching waveform
    waveform = waveforms[most_probable]
    return waveform

if st.button("Submit"):
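    # one full run per click: generate candidates, rank them, play the winner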
    if not text_input:  # st.text_input returns an empty string, never None
        st.error("Please provide a text input.")
    else:
        waveforms = pipe(
            text_input,
            audio_length_in_s=duration,
            guidance_scale=guidance_scale,
            num_inference_steps=100,
            negative_prompt=negative_prompt,
            num_waveforms_per_prompt=n_candidates,
            generator=generator.manual_seed(int(seed)),
        )["audios"]

        if waveforms.shape[0] > 1:
            waveform = score_waveforms(text_input, waveforms)
        else:
            waveform = waveforms[0]

        # specify the sample rate (AudioLDM outputs 16 kHz) and the audio format
        st.audio(waveform, format="audio/wav", sample_rate=16000)

# Note: `browser.gatherUsageStats = false` is Streamlit configuration, not Python;
# it belongs in .streamlit/config.toml.
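# To run this app locally (filename assumed): streamlit run app.py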