"""Streamlit demo: generate audio from a text prompt with AudioLDM.

The app renders a small form (prompt, negative prompt, seed, duration,
guidance scale, number of candidates), runs the AudioLDM diffusers pipeline
when "Submit" is pressed, and, when more than one candidate waveform is
generated, uses a CLAP model to pick the candidate that scores highest
against the prompt.
"""

import streamlit as st
import torch
from diffusers import AudioLDMPipeline
from transformers import AutoProcessor, ClapModel

# NOTE: `browser.gatherUsageStats` is not settable from inside the script via
# `st.set_option`; disable usage stats in `.streamlit/config.toml`
# (`[browser]` / `gatherUsageStats = false`) or with the
# `--browser.gatherUsageStats false` command-line flag instead.

# Streamlit app setup. `set_page_config` has to be the first Streamlit
# command of the script, so it runs before the cached model loaders below.
st.set_page_config(
    page_title="Text to Music",
    page_icon="🎵",
)

# make Space compatible with CPU duplicates
if torch.cuda.is_available():
    device = "cuda"
    torch_dtype = torch.float16
else:
    device = "cpu"
    torch_dtype = torch.float32

repo_id = "cvssp/audioldm-m-full"
clap_repo_id = "sanchit-gandhi/clap-htsat-unfused-m-full"


@st.cache_resource
def _load_pipeline(repo_id, device, _torch_dtype):
    """Load the AudioLDM pipeline once and reuse it across script reruns.

    `_torch_dtype` carries a leading underscore so Streamlit leaves it out of
    the cache key; the dtype is fully determined by `device` in this script.
    """
    pipeline = AudioLDMPipeline.from_pretrained(
        repo_id, torch_dtype=_torch_dtype
    ).to(device)
    pipeline.unet = torch.compile(pipeline.unet)
    return pipeline


@st.cache_resource
def _load_clap(repo_id, device):
    """Load the CLAP model and its processor once and reuse them across reruns."""
    model = ClapModel.from_pretrained(repo_id).to(device)
    clap_processor = AutoProcessor.from_pretrained(repo_id)
    return model, clap_processor


# load the diffusers pipeline
pipe = _load_pipeline(repo_id, device, torch_dtype)

# CLAP model (only required for automatic scoring)
clap_model, processor = _load_clap(clap_repo_id, device)

generator = torch.Generator(device)

text_input = st.text_input("Input text", "A hammer is hitting a wooden surface")
negative_prompt = st.text_input("Negative prompt", "low quality, average quality")

st.markdown("### Configuration")
seed = st.number_input("Seed", value=45)
duration = st.slider("Duration (seconds)", 2.5, 10.0, 5.0, 2.5)
guidance_scale = st.slider("Guidance scale", 0.0, 4.0, 2.5, 0.5)
n_candidates = st.slider("Number waveforms to generate", 1, 3, 3, 1)


def score_waveforms(text, waveforms):
    """Return the candidate waveform that CLAP rates as best matching `text`.

    Args:
        text: The prompt the waveforms were generated from.
        waveforms: The batch of candidate waveforms, indexable along its
            first axis (here: the pipeline's "audios" output).

    Returns:
        The single entry of `waveforms` with the highest audio-text score.
    """
    inputs = processor(
        text=text, audios=list(waveforms), return_tensors="pt", padding=True
    )
    inputs = {key: inputs[key].to(device) for key in inputs}
    with torch.no_grad():
        # this is the audio-text similarity score
        logits_per_text = clap_model(**inputs).logits_per_text
        # we can take the softmax to get the label probabilities
        probs = logits_per_text.softmax(dim=-1)
        # Convert to a plain int so indexing does not depend on which device
        # the tensor lives on.
        most_probable = int(torch.argmax(probs).item())
    # and now select the most likely audio waveform
    return waveforms[most_probable]


if st.button("Submit"):
    # st.text_input yields a string (empty when the field is blank), so test
    # for emptiness rather than for None.
    if not text_input or not text_input.strip():
        st.error("Please provide a text input.")
    else:
        waveforms = pipe(
            text_input,
            audio_length_in_s=duration,
            guidance_scale=guidance_scale,
            num_inference_steps=100,
            negative_prompt=negative_prompt,
            num_waveforms_per_prompt=n_candidates,
            generator=generator.manual_seed(int(seed)),
        )["audios"]

        if waveforms.shape[0] > 1:
            waveform = score_waveforms(text_input, waveforms)
        else:
            waveform = waveforms[0]

        # Specify the sample rate and the audio format for playback
        st.audio(waveform, format="audio/wav", sample_rate=16000)