lklimkiewicz commited on
Commit
a1c7dae
·
verified ·
1 Parent(s): 8b6b1aa

Upload OPTForMusicGeneration

Browse files
Files changed (1) hide show
  1. model.py +3 -5
model.py CHANGED
@@ -4,16 +4,14 @@ import torch
4
  from miditok import TokSequence
5
 
6
 
7
- # class OPTForMusicGenerationConfig(OPTConfig):
8
-
9
-
10
  class OPTForMusicGeneration(OPTForCausalLM):
11
 
12
- def generate_music(self, **kwargs):
13
  input = torch.tensor([[self.config.bos_token_id]], device=self.device)
14
  midi = self.generate(input, **kwargs)
15
  generated_ts = TokSequence(ids=midi.tolist()[0], ids_bpe_encoded=True)
16
- return generated_ts
 
17
 
18
 
19
  OPTForMusicGeneration.register_for_auto_class("AutoModel")
 
4
  from miditok import TokSequence
5
 
6
 
 
 
 
7
  class OPTForMusicGeneration(OPTForCausalLM):
8
 
9
+ def generate_music(self, tokenizer, **kwargs):
10
  input = torch.tensor([[self.config.bos_token_id]], device=self.device)
11
  midi = self.generate(input, **kwargs)
12
  generated_ts = TokSequence(ids=midi.tolist()[0], ids_bpe_encoded=True)
13
+ generated_score = tokenizer(generated_ts)
14
+ return generated_score
15
 
16
 
17
  OPTForMusicGeneration.register_for_auto_class("AutoModel")