Update utils.py
Browse files
utils.py
CHANGED
@@ -18,7 +18,7 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
18 |
model, tokenizer = get_model_and_tokenizer()
|
19 |
|
20 |
|
21 |
-
def create_seed_string(genre: str = "OTHER") -> str:
|
22 |
"""
|
23 |
Creates a seed string for generating a new piece.
|
24 |
|
@@ -28,10 +28,14 @@ def create_seed_string(genre: str = "OTHER") -> str:
|
|
28 |
Returns:
|
29 |
str: The seed string.
|
30 |
"""
|
31 |
-
if genre == "RANDOM":
|
32 |
seed_string = "PIECE_START"
|
|
|
|
|
|
|
|
|
33 |
else:
|
34 |
-
seed_string = f"PIECE_START GENRE={genre} TRACK_START"
|
35 |
return seed_string
|
36 |
|
37 |
|
@@ -235,11 +239,9 @@ def generate_song(
|
|
235 |
instruments string, generated song string, and number of tokens string.
|
236 |
"""
|
237 |
if text_sequence == "":
|
238 |
-
|
239 |
-
seed_string = create_seed_string(seed)
|
240 |
else:
|
241 |
seed_string = text_sequence
|
242 |
-
print(seed_string)
|
243 |
|
244 |
generated_sequence = generate_new_instrument(seed=seed_string, temp=temp)
|
245 |
audio, midi_file, fig, instruments_str, num_tokens = get_outputs_from_string(
|
|
|
18 |
model, tokenizer = get_model_and_tokenizer()
|
19 |
|
20 |
|
21 |
+
def create_seed_string(genre: str = "OTHER", artist: str = "OTHER") -> str:
|
22 |
"""
|
23 |
Creates a seed string for generating a new piece.
|
24 |
|
|
|
28 |
Returns:
|
29 |
str: The seed string.
|
30 |
"""
|
31 |
+
if genre == "RANDOM" and artist == "RANDOM":
|
32 |
seed_string = "PIECE_START"
|
33 |
+
elif genre == "RANDOM" and artist != "RANDOM":
|
34 |
+
seed_string = f"PIECE_START GENRE=RANDOM GENRE={artist} TRACK_START"
|
35 |
+
elif genre != "RANDOM" and artist == "RANDOM":
|
36 |
+
seed_string = f"PIECE_START GENRE={genre} GENRE=RANDOM TRACK_START"
|
37 |
else:
|
38 |
+
seed_string = f"PIECE_START GENRE={genre} GENRE={artist} TRACK_START"
|
39 |
return seed_string
|
40 |
|
41 |
|
|
|
239 |
instruments string, generated song string, and number of tokens string.
|
240 |
"""
|
241 |
if text_sequence == "":
|
242 |
+
seed_string = create_seed_string(genre, artist)
|
|
|
243 |
else:
|
244 |
seed_string = text_sequence
|
|
|
245 |
|
246 |
generated_sequence = generate_new_instrument(seed=seed_string, temp=temp)
|
247 |
audio, midi_file, fig, instruments_str, num_tokens = get_outputs_from_string(
|