File size: 557 Bytes
0b117e9 a1c7dae 0b117e9 a1c7dae 0b117e9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 |
from transformers import OPTForCausalLM, OPTConfig, AutoModel
import torch
from miditok import TokSequence
class OPTForMusicGeneration(OPTForCausalLM):
    """OPT causal language model with a helper for unconditional music generation."""

    def generate_music(self, tokenizer, **kwargs):
        """Generate a token sequence from a lone BOS token and decode it.

        Args:
            tokenizer: Callable applied to the generated ``TokSequence``.
                Presumably a miditok tokenizer that decodes token ids into a
                score object -- confirm against the caller.
            **kwargs: Forwarded unchanged to ``self.generate``.

        Returns:
            Whatever ``tokenizer`` returns for the generated sequence.

        Raises:
            ValueError: If ``self.config.bos_token_id`` is not set, since
                there is then no token to seed generation with.
        """
        bos_token_id = self.config.bos_token_id
        if bos_token_id is None:
            raise ValueError(
                "config.bos_token_id must be set to seed music generation"
            )

        # Batch of one sequence holding only the BOS token, placed on the
        # same device as the model.
        prompt = torch.tensor(
            [[bos_token_id]], dtype=torch.long, device=self.device
        )
        generated_ids = self.generate(prompt, **kwargs)

        # Only the first sequence of the batch is decoded.
        # NOTE(review): `ids_bpe_encoded` appears to be the keyword used by
        # older miditok releases -- confirm against the pinned version.
        generated_ts = TokSequence(
            ids=generated_ids.tolist()[0], ids_bpe_encoded=True
        )
        return tokenizer(generated_ts)
# Register this custom class under the "AutoModel" auto class at import time.
# NOTE(review): presumably so the class can be resolved through `AutoModel`
# when the model is loaded with custom code -- confirm against the
# transformers custom-model documentation.
OPTForMusicGeneration.register_for_auto_class("AutoModel")
|