File size: 1,702 Bytes
0d9f09c
 
 
 
 
 
511c7b4
0d9f09c
 
 
ba4b027
0d9f09c
 
 
 
 
 
 
ba4b027
 
0d9f09c
 
 
 
 
 
 
d043662
511c7b4
d043662
 
 
 
0d9f09c
 
d043662
 
 
0d9f09c
 
 
c049bdf
 
d043662
 
 
aa4c24f
3a35f81
a2ceada
 
 
3a35f81
c049bdf
511c7b4
 
c049bdf
0d9f09c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import gradio as gr
import openai
from t2a import text_to_audio
import joblib
from sentence_transformers import SentenceTransformer
import numpy as np
import os

# Module-level resources, loaded once when the file is executed.
#
# `reg` is deserialized from 'text_reg.joblib' and `model` is a
# SentenceTransformer loaded by name. Neither is referenced by any active
# code visible in this file.
# NOTE(review): presumably these back a prompt -> BPM predictor; confirm
# whether they are still needed, since both are loaded at startup.
reg = joblib.load('text_reg.joblib')
model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
# Identifier of the fine-tuned completion model; passed as `engine` to
# openai.Completion.create in get_note_text.
finetune = "davinci:ft-personal:autodrummer-v5-2022-11-04-22-34-07"

def get_note_text(prompt):
    """Request drum-note text for *prompt* from the fine-tuned model.

    The prompt is sent with " ->" appended, and generation stops at "###"
    or after 200 tokens.

    Args:
        prompt: Text description to complete; must support ``+`` with a
            string.

    Returns:
        The text of the first returned choice with surrounding whitespace
        stripped.
    """
    full_prompt = prompt + " ->"
    request_args = {
        "engine": finetune,
        "prompt": full_prompt,
        "temperature": 0.5,
        "max_tokens": 200,
        "top_p": 1,
        "frequency_penalty": 0,
        "presence_penalty": 0,
        "stop": ["###"],
    }
    completion = openai.Completion.create(**request_args)
    first_choice = completion.choices[0]
    return first_choice.text.strip()

def get_drummer_output(prompt, tempo):
    """Generate a drum beat for *prompt* and return it as an audio tuple.

    Args:
        prompt: Text description forwarded to ``get_note_text``.
        tempo: ``"fast"`` is mapped to 138 and ``"slow"`` to 100; any other
            value is handed to ``text_to_audio`` unchanged.

    Returns:
        A ``(96000, samples)`` tuple, where ``samples`` is a float32 NumPy
        array built from the audio object's ``get_array_of_samples()``.

    Raises:
        KeyError: If the ``key`` environment variable is not set.
    """
    # The API key is (re)read from the environment on every call.
    openai.api_key = os.environ['key']

    bpm = 138 if tempo == "fast" else 100 if tempo == "slow" else tempo

    notes = get_note_text(prompt)
    # Disabled alternatives kept for reference: doubling the note text, and
    # predicting the BPM from the prompt embedding.
    # note_text = note_text + " " + note_text
    # prompt_enc = model.encode([prompt])
    # bpm = int(reg.predict(prompt_enc)[0]) + 20
    clip = text_to_audio(notes, bpm)
    samples = np.array(clip.get_array_of_samples(), dtype=np.float32)
    # NOTE(review): the sample rate is hard-coded rather than read from the
    # audio object - confirm it matches what text_to_audio produces.
    return 96000, samples

# Gradio UI: a free-text prompt and a tempo choice go in, audio comes out.
# Each example row supplies (prompt, tempo) in the same order as `inputs`.
iface = gr.Interface(
    fn=get_drummer_output,
    inputs=[
        "text",
        # `value` is the keyword for the initially selected choice on
        # gr.Radio; the previous `default=` keyword is not a Radio
        # parameter, so no tempo was preselected.
        gr.Radio(["fast", "slow"], label="Tempo", value="fast"),
    ],
    examples=[
        ["hiphop groove 808", "fast"],
        ["rock metal", "fast"],
        ["disco funk", "fast"],
    ],
    outputs="audio",
    title='Autodrummer',
    description="Stable Diffusion for drum beats. Type in a genre and some descriptors (e.g., 'hiphop groove 808') to the prompt box and get a drum beat in that genre"
)
# Launch the interface as soon as this module is executed.
iface.launch()