Update utils.py
Browse files
utils.py
CHANGED
@@ -213,6 +213,7 @@ def change_tempo(
|
|
213 |
|
214 |
def generate_song(
|
215 |
genre: str = "OTHER",
|
|
|
216 |
temp: float = 0.75,
|
217 |
text_sequence: str = "",
|
218 |
qpm: int = 120,
|
@@ -224,6 +225,7 @@ def generate_song(
|
|
224 |
model (AutoModelForCausalLM): The pretrained model used for generating the sequences.
|
225 |
tokenizer (AutoTokenizer): The tokenizer used to encode and decode the sequences.
|
226 |
genre (str, optional): The genre of the song. Defaults to "OTHER".
|
|
|
227 |
temp (float, optional): The temperature for the generation, which controls the randomness. Defaults to 0.75.
|
228 |
text_sequence (str, optional): The initial text sequence for the song. Defaults to "".
|
229 |
qpm (int, optional): The quarter notes per minute. Defaults to 120.
|
@@ -233,7 +235,8 @@ def generate_song(
|
|
233 |
instruments string, generated song string, and number of tokens string.
|
234 |
"""
|
235 |
if text_sequence == "":
|
236 |
-
|
|
|
237 |
else:
|
238 |
seed_string = text_sequence
|
239 |
|
|
|
213 |
|
214 |
def generate_song(
|
215 |
genre: str = "OTHER",
|
216 |
+
artist: str = "KATE_BUSH",
|
217 |
temp: float = 0.75,
|
218 |
text_sequence: str = "",
|
219 |
qpm: int = 120,
|
|
|
225 |
model (AutoModelForCausalLM): The pretrained model used for generating the sequences.
|
226 |
tokenizer (AutoTokenizer): The tokenizer used to encode and decode the sequences.
|
227 |
genre (str, optional): The genre of the song. Defaults to "OTHER".
|
228 |
+
artist (str, optional): The artist style to inspire the song. Defaults to "KATE_BUSH".
|
229 |
temp (float, optional): The temperature for the generation, which controls the randomness. Defaults to 0.75.
|
230 |
text_sequence (str, optional): The initial text sequence for the song. Defaults to "".
|
231 |
qpm (int, optional): The quarter notes per minute. Defaults to 120.
|
|
|
235 |
instruments string, generated song string, and number of tokens string.
|
236 |
"""
|
237 |
if text_sequence == "":
|
238 |
+
seed = (' ').join((genre, artist))
|
239 |
+
seed_string = create_seed_string(seed)
|
240 |
else:
|
241 |
seed_string = text_sequence
|
242 |
|