jspr commited on
Commit
d043662
·
1 Parent(s): 20f56e4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -6
app.py CHANGED
@@ -25,20 +25,26 @@ def get_note_text(prompt):
25
  )
26
  return response.choices[0].text.strip()
27
 
28
- def get_drummer_output(prompt):
29
  openai.api_key = os.environ['key']
 
 
 
 
30
  note_text = get_note_text(prompt)
31
  # note_text = note_text + " " + note_text
32
- prompt_enc = model.encode([prompt])
33
- bpm = int(reg.predict(prompt_enc)[0]) + 20
34
- print(bpm, "bpm", "notes are", note_text)
35
- audio = text_to_audio(note_text, bpm)
36
  audio = np.array(audio.get_array_of_samples(), dtype=np.float32)
37
  return (96000, audio)
38
 
39
  iface = gr.Interface(
40
  fn=get_drummer_output,
41
- inputs="text",
 
 
 
42
  examples=[
43
  "hiphop groove 808",
44
  "rock metal",
 
25
  )
26
  return response.choices[0].text.strip()
27
 
28
+ def get_drummer_output(prompt, tempo):
29
  openai.api_key = os.environ['key']
30
+ if tempo == "fast":
31
+ tempo = 138
32
+ elif tempo == "slow":
33
+ tempo = 100
34
  note_text = get_note_text(prompt)
35
  # note_text = note_text + " " + note_text
36
+ # prompt_enc = model.encode([prompt])
37
+ # bpm = int(reg.predict(prompt_enc)[0]) + 20
38
+ audio = text_to_audio(note_text, tempo)
 
39
  audio = np.array(audio.get_array_of_samples(), dtype=np.float32)
40
  return (96000, audio)
41
 
42
  iface = gr.Interface(
43
  fn=get_drummer_output,
44
+ inputs=[
45
+ "text",
46
+ gr.Radio(["fast", "slow"], label="Tempo", default="fast"),
47
+ ]
48
  examples=[
49
  "hiphop groove 808",
50
  "rock metal",