Spaces:
Runtime error
Runtime error
Commit
·
d86a2ac
1
Parent(s):
82d5983
Update app.py
Browse files
app.py
CHANGED
@@ -94,3 +94,30 @@ if uploaded_files:
|
|
94 |
n_candidates = st.slider("Number waveforms to generate", 1, 3, 3, 1)
|
95 |
|
96 |
def score_waveforms(text, waveforms):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
94 |
n_candidates = st.slider("Number waveforms to generate", 1, 3, 3, 1)
|
95 |
|
96 |
def score_waveforms(text, waveforms):
|
97 |
+
inputs = processor(text=text, audios=list(waveforms), return_tensors="pt", padding=True)
|
98 |
+
inputs = {key: inputs[key].to(device) for key in inputs}
|
99 |
+
with torch.no_grad():
|
100 |
+
logits_per_text = clap_model(**inputs).logits_per_text # this is the audio-text similarity score
|
101 |
+
probs = logits_per_text.softmax(dim=-1) # we can take the softmax to get the label probabilities
|
102 |
+
most_probable = torch.argmax(probs) # and now select the most likely audio waveform
|
103 |
+
waveform = waveforms[most_probable]
|
104 |
+
return waveform
|
105 |
+
|
106 |
+
if st.button("Générer de la musique"):
|
107 |
+
waveforms = pipe(
|
108 |
+
music_input,
|
109 |
+
audio_length_in_s=duration,
|
110 |
+
guidance_scale=guidance_scale,
|
111 |
+
num_inference_steps=100,
|
112 |
+
num_waveforms_per_prompt=n_candidates if n_candidates else 1,
|
113 |
+
generator=generator.manual_seed(int(seed)),
|
114 |
+
)["audios"]
|
115 |
+
|
116 |
+
if waveforms.shape[0] > 1:
|
117 |
+
waveform = score_waveforms(music_input, waveforms)
|
118 |
+
else:
|
119 |
+
waveform = waveforms[0]
|
120 |
+
|
121 |
+
# Afficher le lecteur audio
|
122 |
+
st.audio(waveform, format="audio/wav", sample_rate=16000)
|
123 |
+
|