File size: 2,941 Bytes
a121edc
 
 
 
676ec69
a121edc
676ec69
a121edc
 
 
 
 
 
 
 
 
5152717
a121edc
 
676ec69
 
 
a121edc
 
 
 
 
 
676ec69
 
 
 
 
 
 
 
a121edc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5152717
a121edc
 
5152717
a121edc
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
from pathlib import Path
import json
from typing import List, Union

import soundfile as sf
import torchaudio
from transformers import AutoProcessor, MusicgenForConditionalGeneration

from mm_story_agent.modality_agents.llm import QwenAgent
from mm_story_agent.prompts_en import story_to_music_reviser_system, story_to_music_reviewer_system


class MusicGenSynthesizer:
    """Text-to-music wrapper around a MusicGen checkpoint.

    The processor and model are loaded once in ``__init__``. ``call`` then
    generates audio for a text prompt, resamples it to ``sample_rate`` and
    writes the result to disk.
    """

    def __init__(self,
                 model_name: str = 'facebook/musicgen-medium',
                 device: str = 'cuda',
                 sample_rate: int = 16000,
                 ) -> None:
        """Load the processor and the generation model.

        Args:
            model_name: Checkpoint identifier handed to ``from_pretrained``.
            device: Device the model and the tokenized inputs are moved to.
            sample_rate: Sample rate (Hz) of the audio file that is written.
        """
        self.device = device
        self.processor = AutoProcessor.from_pretrained(model_name)
        self.model = MusicgenForConditionalGeneration.from_pretrained(model_name).to(device)
        self.sample_rate = sample_rate

    @staticmethod
    def _as_prompt_list(prompt: Union[str, List[str]]) -> List[str]:
        """Return ``prompt`` as a flat list of strings.

        A single string is wrapped in a list; any other iterable is copied
        into a list unchanged, so an already-batched prompt is not nested.

        Raises:
            ValueError: If the resulting list is empty.
        """
        prompts = [prompt] if isinstance(prompt, str) else list(prompt)
        if not prompts:
            raise ValueError("prompt must contain at least one text description")
        return prompts

    def call(self,
             prompt: Union[str, List[str]],
             save_path: Union[str, Path],
             max_new_tokens: int = 1536,
             ):
        """Generate music for ``prompt`` and write it to ``save_path``.

        Args:
            prompt: One text description, or a list of descriptions. Only the
                audio generated for the first description is saved, because
                there is a single output path.
            save_path: Destination audio file.
            max_new_tokens: Upper bound on generated audio tokens; a larger
                value yields a longer clip. Defaults to the previously
                hard-coded 1536.

        Raises:
            ValueError: If ``prompt`` is an empty list.
        """
        inputs = self.processor(
            text=self._as_prompt_list(prompt),
            padding=True,
            return_tensors="pt",
        ).to(self.device)
        generated = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
        # Keep the first batch item / first channel only
        # (assumes output is (batch, channels, samples) — TODO confirm).
        wav = generated[0, 0].cpu()
        # Convert from the audio encoder's native rate to the requested rate.
        wav = torchaudio.functional.resample(
            wav,
            self.model.config.audio_encoder.sampling_rate,
            self.sample_rate,
        )
        sf.write(save_path, wav.numpy(), self.sample_rate)


class MusicGenAgent:
    """Derive a music prompt from story pages and synthesize it to a file.

    A "reviser" LLM drafts a music description and a "reviewer" LLM critiques
    it for a bounded number of turns. The final description is then rendered
    with ``MusicGenSynthesizer``.
    """

    def __init__(self, config, llm_type="qwen2") -> None:
        """Store the configuration and select the LLM backend class.

        Args:
            config: Mapping read by ``call``; it must provide the
                ``"revise_cfg"`` and ``"call_cfg"`` entries.
            llm_type: Name of the LLM backend. Only ``"qwen2"`` is supported.

        Raises:
            ValueError: If ``llm_type`` is not a supported backend.
        """
        self.config = config
        if llm_type == "qwen2":
            self.LLM = QwenAgent
        else:
            # Fail here instead of with an AttributeError on first use.
            raise ValueError(
                f"Unsupported llm_type {llm_type!r}; expected 'qwen2'."
            )

    def generate_music_prompt_from_story(
            self,
            pages: List,
            num_turns: int = 3
        ):
        """Iteratively draft and review a music description for a story.

        Args:
            pages: Story pages; they are JSON-serialized into both LLM inputs.
            num_turns: Maximum number of revise/review rounds.

        Returns:
            The last successfully generated music description, or ``""`` if
            no revision succeeded.
        """
        music_prompt_reviser = self.LLM(story_to_music_reviser_system, track_history=False)
        music_prompt_reviewer = self.LLM(story_to_music_reviewer_system, track_history=False)

        music_prompt = ""
        review = ""
        for _ in range(num_turns):
            # NOTE(review): assumes ``run`` returns ``(text, success)`` and that
            # ``text`` is not usable when ``success`` is falsy — confirm against
            # the agent implementation.
            revised, revised_ok = music_prompt_reviser.run(json.dumps({
                "story": pages,
                "previous_result": music_prompt,
                "improvement_suggestions": review,
            }, ensure_ascii=False))
            if not revised_ok:
                # Keep the last good prompt and retry on the next turn.
                continue
            music_prompt = revised

            feedback, reviewed_ok = music_prompt_reviewer.run(json.dumps({
                "story_content": pages,
                "music_description": music_prompt
            }, ensure_ascii=False))
            if not reviewed_ok:
                # Earlier suggestions referred to an older prompt; drop them.
                review = ""
                continue
            review = feedback
            # Tolerate surrounding whitespace in the reviewer's verdict.
            if str(review).strip() == "Check passed.":
                break

        return music_prompt

    def call(self, pages: List, device: str, save_path: str):
        """Generate background music for a story and save it as ``music.wav``.

        Args:
            pages: Story pages used to derive the music prompt.
            device: Device on which the synthesizer is created.
            save_path: Directory that receives ``music.wav``; it is created
                if it does not exist.

        Returns:
            A dict with the final ``"prompt"`` and ``"modality": "music"``.
        """
        save_path = Path(save_path)
        # Make sure the output directory exists before the file is written.
        save_path.mkdir(parents=True, exist_ok=True)
        music_prompt = self.generate_music_prompt_from_story(pages, **self.config["revise_cfg"])
        generation_agent = MusicGenSynthesizer(device=device)
        generation_agent.call(
            prompt=music_prompt,
            save_path=save_path / "music.wav",
            **self.config["call_cfg"]
        )
        return {
            "prompt": music_prompt,
            "modality": "music"
        }