lklimkiewicz
commited on
Upload OPTForMusicGeneration
Browse files
model.py
CHANGED
@@ -4,16 +4,14 @@ import torch
|
|
4 |
from miditok import TokSequence
|
5 |
|
6 |
|
7 |
-
# class OPTForMusicGenerationConfig(OPTConfig):
|
8 |
-
|
9 |
-
|
10 |
class OPTForMusicGeneration(OPTForCausalLM):
|
11 |
|
12 |
-
def generate_music(self, **kwargs):
|
13 |
input = torch.tensor([[self.config.bos_token_id]], device=self.device)
|
14 |
midi = self.generate(input, **kwargs)
|
15 |
generated_ts = TokSequence(ids=midi.tolist()[0], ids_bpe_encoded=True)
|
16 |
-
|
|
|
17 |
|
18 |
|
19 |
OPTForMusicGeneration.register_for_auto_class("AutoModel")
|
|
|
4 |
from miditok import TokSequence
|
5 |
|
6 |
|
|
|
|
|
|
|
7 |
class OPTForMusicGeneration(OPTForCausalLM):
|
8 |
|
9 |
+
def generate_music(self, tokenizer, **kwargs):
|
10 |
input = torch.tensor([[self.config.bos_token_id]], device=self.device)
|
11 |
midi = self.generate(input, **kwargs)
|
12 |
generated_ts = TokSequence(ids=midi.tolist()[0], ids_bpe_encoded=True)
|
13 |
+
generated_score = tokenizer(generated_ts)
|
14 |
+
return generated_score
|
15 |
|
16 |
|
17 |
OPTForMusicGeneration.register_for_auto_class("AutoModel")
|