mrfakename commited on
Commit
3f32750
·
verified ·
1 Parent(s): 7362f82

Upload 8 files

Browse files
Files changed (9) hide show
  1. .gitattributes +1 -0
  2. README.md +30 -15
  3. abc2xml.py +0 -0
  4. config.py +15 -0
  5. demo.py +236 -0
  6. illustration.png +3 -0
  7. inference.py +260 -0
  8. prompts.txt +112 -0
  9. utils.py +406 -0
.gitattributes CHANGED
@@ -35,3 +35,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  examples/web_6f93090a-81f6-489e-bb35-1a2838b18c01.png filter=lfs diff=lfs merge=lfs -text
37
  examples/web_dfacd48d-d2c2-492f-b94c-41e6a34ea99f.png filter=lfs diff=lfs merge=lfs -text
 
 
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  examples/web_6f93090a-81f6-489e-bb35-1a2838b18c01.png filter=lfs diff=lfs merge=lfs -text
37
  examples/web_dfacd48d-d2c2-492f-b94c-41e6a34ea99f.png filter=lfs diff=lfs merge=lfs -text
38
+ illustration.png filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,16 +1,31 @@
1
- ---
2
- title: DeepSeek-R1
3
- emoji: 🐋
4
- colorFrom: indigo
5
- colorTo: blue
6
- sdk: gradio
7
- sdk_version: 5.12.0
8
- app_file: app.py
9
- pinned: false
10
- preload_from_hub:
11
- - deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
12
- - deepseek-ai/DeepSeek-R1-Distill-Qwen-7B
13
- short_description: Try out the distilled DeepSeek-R1 models (MIT licensed!)
14
- ---
15
 
16
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Local Gradio Demo
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
+ 1. Set up the environment:
4
+
5
+ ```
6
+ conda create --name notagen python=3.10
7
+ conda activate notagen
8
+ conda install pytorch==2.3.0 pytorch-cuda=11.8 -c pytorch -c nvidia
9
+ pip install accelerate
10
+ pip install optimum
11
+ pip install -r requirements.txt
12
+ ```
13
+
14
+ 2. Download [NotaGen-X](https://huggingface.co/ElectricAlexis/NotaGen/blob/main/weights_notagenx_p_size_16_p_length_1024_p_layers_20_h_size_1280.pth) and put it under ```gradio/```.
15
+
16
+ 3. Run ```demo.py```:
17
+
18
+ ```
19
+ cd gradio/
20
+ python demo.py
21
+ ```
22
+
23
+ 4. Open the demo page in your browser at http://localhost:7861 (the server listens on 0.0.0.0:7861).
24
+
25
+ <p align="center">
26
+ <img src="illustration.png" alt="NotaGen Gradio Demo">
27
+ </p>
28
+
29
+ You can choose period, composer, and instrumentation as a prompt combination for NotaGen's conditional generation. After generation completes, you can save the ABC notation and MusicXML files locally.
30
+
31
+ Unfortunately, only 112 prompt combinations are currently available, because a combination can only be offered when the fine-tuning dataset contains enough pieces of music for it. We hope to expand the combinations and forms of prompts in the future.
abc2xml.py ADDED
The diff for this file is too large to render. See raw diff
 
config.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ # Configurations for inference
4
+ INFERENCE_WEIGHTS_PATH = 'weights_notagenx_p_size_16_p_length_1024_p_layers_20_h_size_1280.pth' # Path to weights for inference
5
+ TOP_K = 9 # Top k for sampling
6
+ TOP_P = 0.9 # Top p for sampling
7
+ TEMPERATURE = 1.2 # Temperature for sampling
8
+
9
+ # Configurations for model
10
+ PATCH_STREAM = True # Stream training / inference
11
+ PATCH_SIZE = 16 # Patch Size
12
+ PATCH_LENGTH = 1024 # Patch Length
13
+ CHAR_NUM_LAYERS = 6 # Number of layers in the decoder
14
+ PATCH_NUM_LAYERS = 20 # Number of layers in the encoder
15
+ HIDDEN_SIZE = 1280 # Hidden Size
demo.py ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import sys
3
+ import threading
4
+ import queue
5
+ from io import TextIOBase
6
+ from inference import inference_patch
7
+ import datetime
8
+ import subprocess
9
+ import os
10
+
11
+ # Predefined valid combinations set
12
+ with open('prompts.txt', 'r') as f:
13
+ prompts = f.readlines()
14
+ valid_combinations = set()
15
+ for prompt in prompts:
16
+ prompt = prompt.strip()
17
+ parts = prompt.split('_')
18
+ valid_combinations.add((parts[0], parts[1], parts[2]))
19
+
20
+ # Generate available options
21
+ periods = sorted({p for p, _, _ in valid_combinations})
22
+ composers = sorted({c for _, c, _ in valid_combinations})
23
+ instruments = sorted({i for _, _, i in valid_combinations})
24
+
25
+ # Dynamic component updates
26
def update_components(period, composer):
    """Rebuild the composer and instrumentation dropdowns for the current selection.

    Returns a two-element list ``[composer_dropdown, instrument_dropdown]``.
    Without a period both dropdowns are emptied and disabled.  Otherwise the
    composer dropdown lists the composers known for ``period`` (keeping
    ``composer`` selected if it is still valid) and the instrumentation
    dropdown lists the options for that (period, composer) pair; its value is
    always reset to ``None``.
    """
    if not period:
        return [
            gr.Dropdown(choices=[], value=None, interactive=False)
            for _ in range(2)
        ]

    composer_choices = sorted(
        {c for p, c, _ in valid_combinations if p == period}
    )

    instrument_choices = []
    if composer:
        instrument_choices = sorted(
            {i for p, c, i in valid_combinations if p == period and c == composer}
        )

    kept_composer = composer if composer in composer_choices else None

    composer_update = gr.Dropdown(
        choices=composer_choices,
        value=kept_composer,
        interactive=True
    )
    instrument_update = gr.Dropdown(
        choices=instrument_choices,
        value=None,
        # Only selectable once there is something to select.
        interactive=bool(instrument_choices)
    )
    return [composer_update, instrument_update]
48
+
49
+
50
class RealtimeStream(TextIOBase):
    """Write-only text stream that forwards every written chunk to a queue.

    Intended as a stand-in for ``sys.stdout`` so that text printed by another
    thread can be consumed incrementally from ``queue``.
    """

    def __init__(self, queue):
        # Initialise the IOBase machinery (closed flag etc.) before use.
        super().__init__()
        self.queue = queue

    def writable(self):
        """Report the stream as writable (``TextIOBase`` defaults to False)."""
        return True

    def write(self, text):
        """Enqueue ``text`` and return the number of characters accepted."""
        self.queue.put(text)
        return len(text)
57
+
58
+
59
def save_and_convert(abc_content, period, composer, instrumentation):
    """Save the ABC score to disk and convert it to MusicXML with abc2xml.py.

    Files are written to the current working directory and named
    ``<timestamp>_<period>_<composer>_<instrumentation>.abc`` / ``.xml``.

    Args:
        abc_content: ABC notation text to save.
        period, composer, instrumentation: the prompt selection, used only to
            build the file name.

    Returns:
        A human-readable status message naming both files.

    Raises:
        gr.Error: if the prompt selection or the ABC text is missing, or if
            the abc2xml conversion exits with a non-zero status.
    """
    if not all([period, composer, instrumentation]):
        raise gr.Error("Please complete a valid generation first before saving")
    # Nothing was generated yet: do not write and convert an empty score.
    if not abc_content or not abc_content.strip():
        raise gr.Error("Please complete a valid generation first before saving")

    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    prompt_str = f"{period}_{composer}_{instrumentation}"
    filename_base = f"{timestamp}_{prompt_str}"

    abc_filename = f"{filename_base}.abc"
    with open(abc_filename, "w", encoding="utf-8") as f:
        f.write(abc_content)

    xml_filename = f"{filename_base}.xml"
    try:
        subprocess.run(
            # Use the interpreter running this app rather than whatever
            # "python" happens to be first on PATH (may be absent or another env).
            [sys.executable, "abc2xml.py", '-o', '.', abc_filename],
            check=True,
            capture_output=True,
            text=True
        )
    except subprocess.CalledProcessError as e:
        error_msg = f"Conversion failed: {e.stderr}" if e.stderr else "Unknown error"
        raise gr.Error(f"ABC to XML conversion failed: {error_msg}. Please try to generate another composition.")

    return f"Saved successfully: {abc_filename} -> {xml_filename}"
84
+
85
+
86
+
87
def generate_music(period, composer, instrumentation):
    """Run generation in a worker thread and stream its printed output.

    This is a generator used as a Gradio event handler.  While
    ``inference_patch`` runs, ``sys.stdout`` is replaced by a
    ``RealtimeStream`` so that everything it prints can be forwarded to the UI.

    Yields:
        ``(process_output, final_result)`` tuples: the accumulated printed text
        and ``None`` while generation is running, then once more with the
        value returned by ``inference_patch``.

    Raises:
        gr.Error: if the prompt combination is not in ``valid_combinations``
            or if the generation thread raised an exception.
    """
    if (period, composer, instrumentation) not in valid_combinations:
        raise gr.Error("Invalid prompt combination! Please re-select from the period options")

    output_queue = queue.Queue()
    original_stdout = sys.stdout
    sys.stdout = RealtimeStream(output_queue)

    result_container = []
    error_container = []

    def run_inference():
        try:
            result_container.append(inference_patch(period, composer, instrumentation))
        except Exception as exc:
            # Hand the failure over to the caller instead of losing it in the thread.
            error_container.append(exc)
        finally:
            sys.stdout = original_stdout

    thread = threading.Thread(target=run_inference)
    try:
        thread.start()
    except BaseException:
        # The worker never ran, so its ``finally`` will not restore stdout.
        sys.stdout = original_stdout
        raise

    process_output = ""
    while thread.is_alive():
        try:
            text = output_queue.get(timeout=0.1)
        except queue.Empty:
            continue
        process_output += text
        yield process_output, None

    # Drain whatever was printed between the last poll and thread exit.
    while not output_queue.empty():
        text = output_queue.get()
        process_output += text
        yield process_output, None

    if error_container:
        raise gr.Error(f"Generation failed: {error_container[0]}")

    final_result = result_container[0] if result_container else ""
    yield process_output, final_result
121
+
122
# Page layout: prompt selection and live generation log on the left,
# post-processed score and save controls on the right.
with gr.Blocks() as demo:
    gr.Markdown("## NotaGen")

    with gr.Row():
        # Left column
        with gr.Column():
            period_dd = gr.Dropdown(
                choices=periods,
                value=None,
                label="Period",
                interactive=True
            )
            composer_dd = gr.Dropdown(
                choices=[],
                value=None,
                label="Composer",
                interactive=False
            )
            instrument_dd = gr.Dropdown(
                choices=[],
                value=None,
                label="Instrumentation",
                interactive=False
            )

            generate_btn = gr.Button("Generate!", variant="primary")

            process_output = gr.Textbox(
                label="Generation process",
                interactive=False,
                lines=15,
                max_lines=15,
                placeholder="Generation progress will be shown here...",
                elem_classes="process-output"
            )

        # Right column
        with gr.Column():
            final_output = gr.Textbox(
                label="Post-processed ABC notation scores",
                interactive=True,
                lines=23,
                placeholder="Post-processed ABC scores will be shown here...",
                elem_classes="final-output"
            )

            with gr.Row():
                # elem_id gives the hover rule in the CSS below something to match.
                save_btn = gr.Button(
                    "💾 Save as ABC & XML files",
                    variant="secondary",
                    elem_id="save-convert"
                )

            save_status = gr.Textbox(
                label="Save Status",
                interactive=False,
                visible=True,
                max_lines=2
            )

    # Changing either the period or the composer refreshes the dependent dropdowns.
    period_dd.change(
        update_components,
        inputs=[period_dd, composer_dd],
        outputs=[composer_dd, instrument_dd]
    )
    composer_dd.change(
        update_components,
        inputs=[period_dd, composer_dd],
        outputs=[composer_dd, instrument_dd]
    )

    generate_btn.click(
        generate_music,
        inputs=[period_dd, composer_dd, instrument_dd],
        outputs=[process_output, final_output]
    )

    save_btn.click(
        save_and_convert,
        inputs=[final_output, period_dd, composer_dd, instrument_dd],
        outputs=[save_status]
    )


css = """
.process-output {
    background-color: #f0f0f0;
    font-family: monospace;
    padding: 10px;
    border-radius: 5px;
}
.final-output {
    background-color: #ffffff;
    font-family: sans-serif;
    padding: 10px;
    border-radius: 5px;
}

.process-output textarea {
    max-height: 500px !important;
    overflow-y: auto !important;
    white-space: pre-wrap;
}

"""
# The previous selector (button#💾-save-convert) matched no element because
# no component carried that id; it now targets the save button's elem_id.
css += """
button#save-convert:hover {
    background-color: #ffe6e6;
}
"""

demo.css = css

if __name__ == "__main__":

    demo.launch(
        server_name="0.0.0.0",
        server_port=7861
    )
illustration.png ADDED

Git LFS Details

  • SHA256: 10e0d5742ed50035210c40983bdf56d038d0288ebd89881b895e1e50afe609a3
  • Pointer size: 131 Bytes
  • Size of remote file: 384 kB
inference.py ADDED
@@ -0,0 +1,260 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import os
3
+ import time
4
+ import torch
5
+ from utils import *
6
+ from config import *
7
+ from transformers import GPT2Config, LlamaConfig
8
+ from abctoolkit.utils import Exclaim_re, Quote_re, SquareBracket_re, Barline_regexPattern
9
+ from abctoolkit.transpose import Note_list, Pitch_sign_list
10
+ from abctoolkit.duration import calculate_bartext_duration
11
+
12
# Extend the note symbols imported from abctoolkit with 'z' and 'x'
# (used as rest symbols by rest_unreduce below).
Note_list = Note_list + ['z', 'x']

# Prefer the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

patchilizer = Patchilizer()

# Patch-level model configuration.
patch_config = GPT2Config(
    num_hidden_layers=PATCH_NUM_LAYERS,
    max_length=PATCH_LENGTH,
    max_position_embeddings=PATCH_LENGTH,
    n_embd=HIDDEN_SIZE,
    num_attention_heads=HIDDEN_SIZE // 64,
    vocab_size=1,
)
# Character-level model configuration; one extra position beyond the patch size.
byte_config = GPT2Config(
    num_hidden_layers=CHAR_NUM_LAYERS,
    max_length=PATCH_SIZE + 1,
    max_position_embeddings=PATCH_SIZE + 1,
    hidden_size=HIDDEN_SIZE,
    num_attention_heads=HIDDEN_SIZE // 64,
    vocab_size=128,
)

model = NotaGenLMHeadModel(encoder_config=patch_config, decoder_config=byte_config)

trainable_param_count = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("Parameter Number: " + str(trainable_param_count))

# Load the weights onto the selected device and switch to evaluation mode.
checkpoint = torch.load(INFERENCE_WEIGHTS_PATH, map_location=device)
model.load_state_dict(checkpoint['model'])
model = model.to(device)
model.eval()
42
+
43
+
44
def rest_unreduce(abc_lines):
    """Fill in the voices that are missing from each tune-body line with rests.

    The header lines (everything before the first line containing ``[V:``)
    are scanned for ``V:`` declarations and for parenthesised voice groups on
    ``%%score`` lines.  Every tune-body line is then rebuilt so that it holds
    one ``[V:n]`` segment per declared voice, in ascending voice number.  A
    voice absent from a line gets a rest lasting the most common bar duration
    found on that line: ``z`` for the first voice of a group, ``x`` for the
    other voices of the group.

    Args:
        abc_lines: ABC text as a list of lines.

    Returns:
        The header lines followed by the rebuilt tune-body lines, each ending
        in a newline.

    Raises:
        ValueError: if there is no tune body, or if a rest has to be inserted
            before any bar duration / closing barline is known.
    """
    tunebody_index = None
    for i, line in enumerate(abc_lines):
        if '[V:' in line:
            tunebody_index = i
            break
    if tunebody_index is None:
        # Slicing with None would treat every line as both header and body.
        raise ValueError('ABC text contains no tune body ([V:...] line)')

    metadata_lines = abc_lines[:tunebody_index]
    tunebody_lines = abc_lines[tunebody_index:]

    part_symbol_list = []
    voice_group_list = []
    for line in metadata_lines:
        if line.startswith('%%score'):
            for round_bracket_match in re.findall(r'\((.*?)\)', line):
                voice_group_list.append(round_bracket_match.split())
        if line.startswith('V:'):
            symbol = line.split()[0]
            part_symbol_list.append(symbol)
            # A voice not listed in any group forms a group of its own.
            # (Checked against the groups directly so that a header without
            # a %%score line no longer hits an unassigned variable.)
            if not any(symbol[2:] in group for group in voice_group_list):
                voice_group_list.append([symbol[2:]])

    z_symbol_list = []  # voices that use z as rest
    x_symbol_list = []  # voices that use x as rest
    for voice_group in voice_group_list:
        if not voice_group:
            continue  # empty "()" in %%score
        z_symbol_list.append('V:' + voice_group[0])
        x_symbol_list.extend('V:' + voice for voice in voice_group[1:])

    part_symbol_list.sort(key=lambda x: int(x[2:]))

    unreduced_tunebody_lines = []
    # Both values intentionally carry over from one line to the next:
    # a line without usable bars reuses the last known duration / barline.
    ref_dur = None
    right_barline = None

    for i, line in enumerate(tunebody_lines):
        unreduced_line = ''

        # Drop the leading stream marker "[r:...]" if present.
        line = re.sub(r'^\[r:[^\]]*\]', '', line)

        pattern = r'\[V:(\d+)\](.*?)(?=\[V:|$)'
        matches = re.findall(pattern, line)

        line_bar_dict = {}
        for match in matches:
            line_bar_dict[f'V:{match[0]}'] = match[1]

        # calculate duration and collect barline
        dur_dict = {}
        for symbol, bartext in line_bar_dict.items():
            right_barline = ''.join(re.split(Barline_regexPattern, bartext)[-2:])
            bartext = bartext[:-len(right_barline)]
            try:
                bar_dur = calculate_bartext_duration(bartext)
            except Exception:
                bar_dur = None
            if bar_dur is not None:
                dur_dict[bar_dur] = dur_dict.get(bar_dur, 0) + 1

        if dur_dict:
            # Most frequent bar duration on this line; otherwise keep the last one.
            ref_dur = max(dur_dict, key=dur_dict.get)

        if i == 0:
            prefix_left_barline = line.split('[V:')[0]
        else:
            prefix_left_barline = ''

        for symbol in part_symbol_list:
            if symbol in line_bar_dict:
                symbol_bartext = line_bar_dict[symbol]
            else:
                if ref_dur is None or right_barline is None:
                    raise ValueError(
                        'cannot insert a rest for ' + symbol +
                        ': no bar duration or barline is known yet'
                    )
                rest_char = 'z' if symbol in z_symbol_list else 'x'
                symbol_bartext = prefix_left_barline + rest_char + str(ref_dur) + right_barline
            unreduced_line += '[' + symbol + ']' + symbol_bartext

        unreduced_tunebody_lines.append(unreduced_line + '\n')

    unreduced_lines = metadata_lines + unreduced_tunebody_lines

    return unreduced_lines
132
+
133
+
134
def inference_patch(period, composer, instrumentation):
    """Generate one ABC score for the given prompt and return it as text.

    The three prompt values are fed to the model as ``%``-prefixed header
    lines, then patches are generated one at a time and printed as they
    arrive.  A sample is discarded and generation restarts from scratch when
    it exceeds 102400 characters, takes longer than 20 minutes, or cannot be
    post-processed by ``rest_unreduce``.

    Returns:
        The post-processed ABC text: ``X:1`` header added, single-``%``
        prompt lines removed, missing voices filled with rests.
    """
    prompt_lines = [
        '%' + period + '\n',
        '%' + composer + '\n',
        '%' + instrumentation + '\n']

    while True:

        failure_flag = False

        bos_patch = [patchilizer.bos_token_id] * (PATCH_SIZE - 1) + [patchilizer.eos_token_id]

        start_time = time.time()

        prompt_patches = patchilizer.patchilize_metadata(prompt_lines)
        byte_list = list(''.join(prompt_lines))
        print(''.join(byte_list), end='')

        # Convert each text patch to ids, padded to PATCH_SIZE.
        prompt_patches = [[ord(c) for c in patch] + [patchilizer.special_token_id] * (PATCH_SIZE - len(patch)) for patch
                          in prompt_patches]
        prompt_patches.insert(0, bos_patch)

        input_patches = torch.tensor(prompt_patches, device=device).reshape(1, -1)

        end_flag = False
        cut_index = None

        tunebody_flag = False

        while True:
            predicted_patch = model.generate(input_patches.unsqueeze(0),
                                             top_k=TOP_K,
                                             top_p=TOP_P,
                                             temperature=TEMPERATURE)
            if not tunebody_flag and patchilizer.decode([predicted_patch]).startswith('[r:'):  # start with [r:0/
                # First tune-body patch: force the stream marker to begin with
                # "[r:0/" and let the model complete the rest of the patch.
                tunebody_flag = True
                r0_patch = torch.tensor([ord(c) for c in '[r:0/']).unsqueeze(0).to(device)
                temp_input_patches = torch.concat([input_patches, r0_patch], axis=-1)
                predicted_patch = model.generate(temp_input_patches.unsqueeze(0),
                                                 top_k=TOP_K,
                                                 top_p=TOP_P,
                                                 temperature=TEMPERATURE)
                predicted_patch = [ord(c) for c in '[r:0/'] + predicted_patch
            # A patch starting with bos followed by eos marks the end of the piece.
            if predicted_patch[0] == patchilizer.bos_token_id and predicted_patch[1] == patchilizer.eos_token_id:
                end_flag = True
                break
            next_patch = patchilizer.decode([predicted_patch])

            for char in next_patch:
                byte_list.append(char)
                print(char, end='')

            # Everything after the first eos inside the patch becomes padding.
            patch_end_flag = False
            for j in range(len(predicted_patch)):
                if patch_end_flag:
                    predicted_patch[j] = patchilizer.special_token_id
                if predicted_patch[j] == patchilizer.eos_token_id:
                    patch_end_flag = True

            predicted_patch = torch.tensor([predicted_patch], device=device)  # (1, 16)
            input_patches = torch.cat([input_patches, predicted_patch], dim=1)  # (1, 16 * patch_len)

            if len(byte_list) > 102400:
                failure_flag = True
                break
            if time.time() - start_time > 20 * 60:
                failure_flag = True
                break

            if input_patches.shape[1] >= PATCH_LENGTH * PATCH_SIZE and not end_flag:
                # Context is full: keep the header plus the tail of the tune
                # body and re-encode that as the new model input.
                print('Stream generating...')
                abc_code = ''.join(byte_list)
                abc_lines = abc_code.split('\n')

                tunebody_index = None
                for i, line in enumerate(abc_lines):
                    if line.startswith('[r:') or line.startswith('[V:'):
                        tunebody_index = i
                        break
                if tunebody_index is None or tunebody_index == len(abc_lines) - 1:
                    break

                metadata_lines = abc_lines[:tunebody_index]
                tunebody_lines = abc_lines[tunebody_index:]

                metadata_lines = [line + '\n' for line in metadata_lines]
                if not abc_code.endswith('\n'):
                    # The last line is still being generated: no newline yet.
                    tunebody_lines = [tunebody_lines[i] + '\n' for i in range(len(tunebody_lines) - 1)] + [
                        tunebody_lines[-1]]
                else:
                    tunebody_lines = [tunebody_lines[i] + '\n' for i in range(len(tunebody_lines))]

                if cut_index is None:
                    cut_index = len(tunebody_lines) // 2

                abc_code_slice = ''.join(metadata_lines + tunebody_lines[-cut_index:])
                input_patches = patchilizer.encode_generate(abc_code_slice)

                input_patches = [item for sublist in input_patches for item in sublist]
                input_patches = torch.tensor([input_patches], device=device)
                input_patches = input_patches.reshape(1, -1)

        if not failure_flag:
            abc_text = ''.join(byte_list)

            # unreduce
            abc_lines = abc_text.split('\n')
            abc_lines = list(filter(None, abc_lines))
            abc_lines = [line + '\n' for line in abc_lines]
            try:
                unreduced_abc_lines = rest_unreduce(abc_lines)
            except Exception as exc:
                # Reject this sample and generate again, but say why, and let
                # KeyboardInterrupt / SystemExit through.
                failure_flag = True
                print(f'\nPost-processing failed ({exc!r}); regenerating...')
            else:
                # Drop the single-% prompt lines, keep %% directives.
                unreduced_abc_lines = [line for line in unreduced_abc_lines if not(line.startswith('%') and not line.startswith('%%'))]
                unreduced_abc_lines = ['X:1\n'] + unreduced_abc_lines
                unreduced_abc_text = ''.join(unreduced_abc_lines)
                return unreduced_abc_text
254
+
255
+
256
+
257
+
258
+
259
+ if __name__ == '__main__':
260
+ inference_patch('Classical', 'Beethoven, Ludwig van', 'Keyboard')
prompts.txt ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Baroque_Bach, Johann Sebastian_Chamber
2
+ Baroque_Bach, Johann Sebastian_Choral
3
+ Baroque_Bach, Johann Sebastian_Keyboard
4
+ Baroque_Bach, Johann Sebastian_Orchestral
5
+ Baroque_Bach, Johann Sebastian_Vocal-Orchestral
6
+ Baroque_Corelli, Arcangelo_Chamber
7
+ Baroque_Corelli, Arcangelo_Orchestral
8
+ Baroque_Handel, George Frideric_Chamber
9
+ Baroque_Handel, George Frideric_Keyboard
10
+ Baroque_Handel, George Frideric_Orchestral
11
+ Baroque_Handel, George Frideric_Vocal-Orchestral
12
+ Baroque_Scarlatti, Domenico_Keyboard
13
+ Baroque_Vivaldi, Antonio_Chamber
14
+ Baroque_Vivaldi, Antonio_Orchestral
15
+ Baroque_Vivaldi, Antonio_Vocal-Orchestral
16
+ Classical_Beethoven, Ludwig van_Art Song
17
+ Classical_Beethoven, Ludwig van_Chamber
18
+ Classical_Beethoven, Ludwig van_Keyboard
19
+ Classical_Beethoven, Ludwig van_Orchestral
20
+ Classical_Haydn, Joseph_Chamber
21
+ Classical_Haydn, Joseph_Keyboard
22
+ Classical_Haydn, Joseph_Orchestral
23
+ Classical_Haydn, Joseph_Vocal-Orchestral
24
+ Classical_Mozart, Wolfgang Amadeus_Chamber
25
+ Classical_Mozart, Wolfgang Amadeus_Choral
26
+ Classical_Mozart, Wolfgang Amadeus_Keyboard
27
+ Classical_Mozart, Wolfgang Amadeus_Orchestral
28
+ Classical_Mozart, Wolfgang Amadeus_Vocal-Orchestral
29
+ Classical_Paradis, Maria Theresia von_Art Song
30
+ Classical_Reichardt, Louise_Art Song
31
+ Classical_Saint-Georges, Joseph Bologne_Chamber
32
+ Classical_Schroter, Corona_Art Song
33
+ Romantic_Bartok, Bela_Keyboard
34
+ Romantic_Berlioz, Hector_Choral
35
+ Romantic_Bizet, Georges_Art Song
36
+ Romantic_Boulanger, Lili_Art Song
37
+ Romantic_Boulton, Harold_Art Song
38
+ Romantic_Brahms, Johannes_Art Song
39
+ Romantic_Brahms, Johannes_Chamber
40
+ Romantic_Brahms, Johannes_Choral
41
+ Romantic_Brahms, Johannes_Keyboard
42
+ Romantic_Brahms, Johannes_Orchestral
43
+ Romantic_Burgmuller, Friedrich_Keyboard
44
+ Romantic_Butterworth, George_Art Song
45
+ Romantic_Chaminade, Cecile_Art Song
46
+ Romantic_Chausson, Ernest_Art Song
47
+ Romantic_Chopin, Frederic_Art Song
48
+ Romantic_Chopin, Frederic_Keyboard
49
+ Romantic_Cornelius, Peter_Art Song
50
+ Romantic_Debussy, Claude_Art Song
51
+ Romantic_Debussy, Claude_Keyboard
52
+ Romantic_Dvorak, Antonin_Chamber
53
+ Romantic_Dvorak, Antonin_Choral
54
+ Romantic_Dvorak, Antonin_Keyboard
55
+ Romantic_Dvorak, Antonin_Orchestral
56
+ Romantic_Faisst, Clara_Art Song
57
+ Romantic_Faure, Gabriel_Art Song
58
+ Romantic_Faure, Gabriel_Chamber
59
+ Romantic_Faure, Gabriel_Keyboard
60
+ Romantic_Franz, Robert_Art Song
61
+ Romantic_Gonzaga, Chiquinha_Art Song
62
+ Romantic_Grandval, Clemence de_Art Song
63
+ Romantic_Grieg, Edvard_Keyboard
64
+ Romantic_Grieg, Edvard_Orchestral
65
+ Romantic_Hensel, Fanny_Art Song
66
+ Romantic_Holmes, Augusta Mary Anne_Art Song
67
+ Romantic_Jaell, Marie_Art Song
68
+ Romantic_Kinkel, Johanna_Art Song
69
+ Romantic_Kralik, Mathilde_Art Song
70
+ Romantic_Lang, Josephine_Art Song
71
+ Romantic_Lehmann, Liza_Art Song
72
+ Romantic_Liszt, Franz_Keyboard
73
+ Romantic_Mayer, Emilie_Chamber
74
+ Romantic_Medtner, Nikolay_Keyboard
75
+ Romantic_Mendelssohn, Felix_Art Song
76
+ Romantic_Mendelssohn, Felix_Chamber
77
+ Romantic_Mendelssohn, Felix_Choral
78
+ Romantic_Mendelssohn, Felix_Keyboard
79
+ Romantic_Mendelssohn, Felix_Orchestral
80
+ Romantic_Munktell, Helena_Art Song
81
+ Romantic_Parratt, Walter_Choral
82
+ Romantic_Prokofiev, Sergey_Keyboard
83
+ Romantic_Rachmaninoff, Sergei_Choral
84
+ Romantic_Rachmaninoff, Sergei_Keyboard
85
+ Romantic_Ravel, Maurice_Art Song
86
+ Romantic_Ravel, Maurice_Chamber
87
+ Romantic_Ravel, Maurice_Keyboard
88
+ Romantic_Saint-Saens, Camille_Chamber
89
+ Romantic_Saint-Saens, Camille_Keyboard
90
+ Romantic_Saint-Saens, Camille_Orchestral
91
+ Romantic_Satie, Erik_Art Song
92
+ Romantic_Satie, Erik_Keyboard
93
+ Romantic_Schubert, Franz_Art Song
94
+ Romantic_Schubert, Franz_Chamber
95
+ Romantic_Schubert, Franz_Choral
96
+ Romantic_Schubert, Franz_Keyboard
97
+ Romantic_Schumann, Clara_Art Song
98
+ Romantic_Schumann, Robert_Art Song
99
+ Romantic_Schumann, Robert_Chamber
100
+ Romantic_Schumann, Robert_Choral
101
+ Romantic_Schumann, Robert_Keyboard
102
+ Romantic_Scriabin, Aleksandr_Keyboard
103
+ Romantic_Shostakovich, Dmitry_Chamber
104
+ Romantic_Shostakovich, Dmitry_Keyboard
105
+ Romantic_Sibelius, Jean_Keyboard
106
+ Romantic_Smetana, Bedrich_Keyboard
107
+ Romantic_Tchaikovsky, Pyotr_Keyboard
108
+ Romantic_Tchaikovsky, Pyotr_Orchestral
109
+ Romantic_Viardot, Pauline_Art Song
110
+ Romantic_Warlock, Peter_Art Song
111
+ Romantic_Wolf, Hugo_Art Song
112
+ Romantic_Zumsteeg, Emilie_Art Song
utils.py ADDED
@@ -0,0 +1,406 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import random
3
+ import bisect
4
+ import json
5
+ import re
6
+ from config import *
7
+ from transformers import GPT2Model, GPT2LMHeadModel, LlamaModel, LlamaForCausalLM, PreTrainedModel
8
+ from samplings import top_p_sampling, top_k_sampling, temperature_sampling
9
+ from tokenizers import Tokenizer
10
+
11
+
12
+ class Patchilizer:
13
+ def __init__(self, stream=PATCH_STREAM):
14
+ self.stream = stream
15
+ self.delimiters = ["|:", "::", ":|", "[|", "||", "|]", "|"]
16
+ self.regexPattern = '(' + '|'.join(map(re.escape, self.delimiters)) + ')'
17
+ self.bos_token_id = 1
18
+ self.eos_token_id = 2
19
+ self.special_token_id = 0
20
+
21
+ def split_bars(self, body_lines):
22
+ """
23
+ Split a body of music into individual bars.
24
+ """
25
+ new_bars = []
26
+ try:
27
+ for line in body_lines:
28
+ line_bars = re.split(self.regexPattern, line)
29
+ line_bars = list(filter(None, line_bars))
30
+ new_line_bars = []
31
+
32
+ if len(line_bars) == 1:
33
+ new_line_bars = line_bars
34
+ else:
35
+ if line_bars[0] in self.delimiters:
36
+ new_line_bars = [line_bars[i] + line_bars[i + 1] for i in range(0, len(line_bars), 2)]
37
+ else:
38
+ new_line_bars = [line_bars[0]] + [line_bars[i] + line_bars[i + 1] for i in range(1, len(line_bars), 2)]
39
+ if 'V' not in new_line_bars[-1]:
40
+ new_line_bars[-2] += new_line_bars[-1] # 吸收最后一个 小节线+\n 的组合
41
+ new_line_bars = new_line_bars[:-1]
42
+ new_bars += new_line_bars
43
+ except:
44
+ pass
45
+
46
+ return new_bars
47
+
48
+ def split_patches(self, abc_text, patch_size=PATCH_SIZE, generate_last=False):
49
+ if not generate_last and len(abc_text) % patch_size != 0:
50
+ abc_text += chr(self.eos_token_id)
51
+ patches = [abc_text[i : i + patch_size] for i in range(0, len(abc_text), patch_size)]
52
+ return patches
53
+
54
+ def patch2chars(self, patch):
55
+ """
56
+ Convert a patch into a bar.
57
+ """
58
+ bytes = ''
59
+ for idx in patch:
60
+ if idx == self.eos_token_id:
61
+ break
62
+ if idx < self.eos_token_id:
63
+ pass
64
+ bytes += chr(idx)
65
+ return bytes
66
+
67
+
68
+ def patchilize_metadata(self, metadata_lines):
69
+
70
+ metadata_patches = []
71
+ for line in metadata_lines:
72
+ metadata_patches += self.split_patches(line)
73
+
74
+ return metadata_patches
75
+
76
+ def patchilize_tunebody(self, tunebody_lines, encode_mode='train'):
77
+
78
+ tunebody_patches = []
79
+ bars = self.split_bars(tunebody_lines)
80
+ if encode_mode == 'train':
81
+ for bar in bars:
82
+ tunebody_patches += self.split_patches(bar)
83
+ elif encode_mode == 'generate':
84
+ for bar in bars[:-1]:
85
+ tunebody_patches += self.split_patches(bar)
86
+ tunebody_patches += self.split_patches(bars[-1], generate_last=True)
87
+
88
+ return tunebody_patches
89
+
90
+ def encode_train(self, abc_text, patch_length=PATCH_LENGTH, patch_size=PATCH_SIZE, add_special_patches=True, cut=True):
91
+
92
+ lines = abc_text.split('\n')
93
+ lines = list(filter(None, lines))
94
+ lines = [line + '\n' for line in lines]
95
+
96
+ tunebody_index = -1
97
+ for i, line in enumerate(lines):
98
+ if '[V:' in line:
99
+ tunebody_index = i
100
+ break
101
+
102
+ metadata_lines = lines[ : tunebody_index]
103
+ tunebody_lines = lines[tunebody_index : ]
104
+
105
+ if self.stream:
106
+ tunebody_lines = ['[r:' + str(line_index) + '/' + str(len(tunebody_lines) - line_index - 1) + ']' + line for line_index, line in
107
+ enumerate(tunebody_lines)]
108
+
109
+ metadata_patches = self.patchilize_metadata(metadata_lines)
110
+ tunebody_patches = self.patchilize_tunebody(tunebody_lines, encode_mode='train')
111
+
112
+ if add_special_patches:
113
+ bos_patch = chr(self.bos_token_id) * (patch_size - 1) + chr(self.eos_token_id)
114
+ eos_patch = chr(self.bos_token_id) + chr(self.eos_token_id) * (patch_size - 1)
115
+
116
+ metadata_patches = [bos_patch] + metadata_patches
117
+ tunebody_patches = tunebody_patches + [eos_patch]
118
+
119
+ if self.stream:
120
+ if len(metadata_patches) + len(tunebody_patches) > patch_length:
121
+ available_cut_indexes = [0] + [index + 1 for index, patch in enumerate(tunebody_patches) if '\n' in patch]
122
+ line_index_for_cut_index = list(range(len(available_cut_indexes)))
123
+ end_index = len(metadata_patches) + len(tunebody_patches) - patch_length
124
+ biggest_index = bisect.bisect_left(available_cut_indexes, end_index)
125
+ available_cut_indexes = available_cut_indexes[:biggest_index + 1]
126
+
127
+ if len(available_cut_indexes) == 1:
128
+ choices = ['head']
129
+ elif len(available_cut_indexes) == 2:
130
+ choices = ['head', 'tail']
131
+ else:
132
+ choices = ['head', 'tail', 'middle']
133
+ choice = random.choice(choices)
134
+ if choice == 'head':
135
+ patches = metadata_patches + tunebody_patches[0:]
136
+ else:
137
+ if choice == 'tail':
138
+ cut_index = len(available_cut_indexes) - 1
139
+ else:
140
+ cut_index = random.choice(range(1, len(available_cut_indexes) - 1))
141
+
142
+ line_index = line_index_for_cut_index[cut_index]
143
+ stream_tunebody_lines = tunebody_lines[line_index : ]
144
+
145
+ stream_tunebody_patches = self.patchilize_tunebody(stream_tunebody_lines, encode_mode='train')
146
+ if add_special_patches:
147
+ stream_tunebody_patches = stream_tunebody_patches + [eos_patch]
148
+ patches = metadata_patches + stream_tunebody_patches
149
+ else:
150
+ patches = metadata_patches + tunebody_patches
151
+ else:
152
+ patches = metadata_patches + tunebody_patches
153
+
154
+ if cut:
155
+ patches = patches[ : patch_length]
156
+ else:
157
+ pass
158
+
159
+ # encode to ids
160
+ id_patches = []
161
+ for patch in patches:
162
+ id_patch = [ord(c) for c in patch] + [self.special_token_id] * (patch_size - len(patch))
163
+ id_patches.append(id_patch)
164
+
165
+ return id_patches
166
+
167
def encode_generate(self, abc_code, patch_length=PATCH_LENGTH, patch_size=PATCH_SIZE, add_special_patches=True):
    """
    Encode ABC text into id patches to be used as a generation prompt.

    The text is split on newlines (empty lines are dropped) into metadata
    lines and tunebody lines; the tunebody starts at the first line that
    begins with '[V:' or '[r:'. Both parts are patchilized, optionally
    preceded by a bos patch, truncated to ``patch_length`` patches and
    converted to lists of character codes.

    A patch that is shorter than ``patch_size`` and does not end with the
    eos character is returned unpadded; every other patch is padded with
    ``self.special_token_id`` up to ``patch_size``.

    :param abc_code: the ABC text to encode
    :param patch_length: maximum number of patches to keep
    :param patch_size: number of ids in a full patch
    :param add_special_patches: whether to prepend the bos patch
    :return: a list of patches, each a list of integer ids
    """
    lines = [line for line in abc_code.split('\n') if line]

    # Default to "no tunebody": when no marker line exists every line is metadata.
    # NOTE(review): the earlier code left this index as None in that case, which
    # made both slices below return all lines and encoded the input twice.
    tunebody_index = len(lines)
    for i, line in enumerate(lines):
        if line.startswith(('[V:', '[r:')):
            tunebody_index = i
            break

    metadata_lines = [line + '\n' for line in lines[:tunebody_index]]
    tunebody_lines = lines[tunebody_index:]

    # In stream mode, text that does not end with a newline keeps its last
    # tunebody line without a terminating newline.
    keep_last_line_open = self.stream and not abc_code.endswith('\n')
    if tunebody_lines and keep_last_line_open:
        tunebody_lines = [line + '\n' for line in tunebody_lines[:-1]] + [tunebody_lines[-1]]
    else:
        tunebody_lines = [line + '\n' for line in tunebody_lines]

    metadata_patches = self.patchilize_metadata(metadata_lines)
    if tunebody_lines:
        tunebody_patches = self.patchilize_tunebody(tunebody_lines, encode_mode='generate')
    else:
        # Nothing to patchilize; avoids indexing into an empty tunebody.
        tunebody_patches = []

    if add_special_patches:
        bos_patch = chr(self.bos_token_id) * (patch_size - 1) + chr(self.eos_token_id)
        metadata_patches = [bos_patch] + metadata_patches

    patches = (metadata_patches + tunebody_patches)[:patch_length]

    # encode to ids
    eos_char = chr(self.eos_token_id)
    id_patches = []
    for patch in patches:
        id_patch = [ord(c) for c in patch]
        # Pad only full-length patches and patches closed by eos; the check uses
        # the patch_size argument so it agrees with the padding amount.
        if len(patch) >= patch_size or patch.endswith(eos_char):
            id_patch += [self.special_token_id] * (patch_size - len(patch))
        id_patches.append(id_patch)

    return id_patches
211
+
212
def decode(self, patches):
    """
    Turn a sequence of patches back into one music string.

    Each patch is converted with ``self.patch2chars`` and the resulting
    pieces are concatenated in order.

    :param patches: an iterable of patches
    :return: the concatenated text of all patches
    """
    pieces = []
    for patch in patches:
        pieces.append(self.patch2chars(patch))
    return ''.join(pieces)
217
+
218
+
219
+
220
+
221
class PatchLevelDecoder(PreTrainedModel):
    """
    A Patch-level Decoder model for generating patch features in an auto-regressive manner.
    It inherits PreTrainedModel from transformers.
    """
    def __init__(self, config):
        """
        Build the patch embedding layer and the GPT-2 backbone.

        :param config: the configuration passed to PreTrainedModel and GPT2Model;
                       its n_embd sets the embedding width
        """
        super().__init__(config)
        # A patch enters as PATCH_SIZE one-hot vectors of width 128, flattened into one vector.
        self.patch_embedding = torch.nn.Linear(PATCH_SIZE * 128, config.n_embd)
        torch.nn.init.normal_(self.patch_embedding.weight, std=0.02)
        self.base = GPT2Model(config)

    def forward(self,
                patches: torch.Tensor,
                masks=None) -> torch.Tensor:
        """
        The forward pass of the patch-level decoder model.
        :param patches: the patches to be encoded, as integer ids in [0, 128)
        :param masks: the masks for the patches, or None to run without an attention mask
        :return: the output of the GPT-2 backbone; callers in this file read
                 its "last_hidden_state" entry
        """
        patches = torch.nn.functional.one_hot(patches, num_classes=128).to(self.dtype)
        # Flatten the one-hot characters of each patch into a single feature vector.
        patches = patches.reshape(len(patches), -1, PATCH_SIZE * (128))
        patches = self.patch_embedding(patches.to(self.device))

        # Identity test: comparing a tensor to None with == is not a reliable None check.
        if masks is None:
            return self.base(inputs_embeds=patches)
        return self.base(inputs_embeds=patches,
                         attention_mask=masks)
250
+
251
+
252
class CharLevelDecoder(PreTrainedModel):
    """
    A Char-level Decoder model for generating the chars within each patch in an auto-regressive manner
    based on the encoded patch features. It inherits PreTrainedModel from transformers.
    """
    def __init__(self, config):
        """
        Build the GPT-2 language-model head used to decode characters.

        :param config: the configuration passed to PreTrainedModel and GPT2LMHeadModel
        """
        super().__init__(config)
        self.special_token_id = 0
        self.bos_token_id = 1

        self.base = GPT2LMHeadModel(config)

    def forward(self,
                encoded_patches: torch.Tensor,
                target_patches: torch.Tensor):
        """
        The forward pass of the char-level decoder model.
        :param encoded_patches: the encoded patches, one feature vector per target patch
        :param target_patches: the target patches, one row of token ids per patch
        :return: the output of the model (computed with labels, so it carries the loss)
        """
        # preparing the labels for model training: prepend a bos column to every patch
        bos_column = torch.ones_like(target_patches[:, 0:1]) * self.bos_token_id
        target_patches = torch.cat((bos_column, target_patches), dim=1)

        # special (padding) tokens get the label -100 so they are ignored by the loss
        labels = target_patches.clone().masked_fill_(target_patches == self.special_token_id, -100)

        # masking the labels for model training: attention is 0 exactly where the label is ignored
        target_masks = torch.ones_like(labels)
        target_masks = target_masks.masked_fill_(labels == -100, 0)

        # select patches: optionally train on a random subset of the patches
        if PATCH_SAMPLING_BATCH_SIZE != 0 and PATCH_SAMPLING_BATCH_SIZE < target_patches.shape[0]:
            indices = list(range(len(target_patches)))
            random.shuffle(indices)
            selected_indices = sorted(indices[:PATCH_SAMPLING_BATCH_SIZE])

            target_patches = target_patches[selected_indices, :]
            target_masks = target_masks[selected_indices, :]
            encoded_patches = encoded_patches[selected_indices, :]
            # The labels must follow the same selection, otherwise their batch
            # size no longer matches the inputs handed to the model below.
            labels = labels[selected_indices, :]

        # get input embeddings
        inputs_embeds = torch.nn.functional.embedding(target_patches, self.base.transformer.wte.weight)

        # concatenate the encoded patches with the input embeddings:
        # the encoded patch takes the place of the bos embedding at position 0
        inputs_embeds = torch.cat((encoded_patches.unsqueeze(1), inputs_embeds[:, 1:, :]), dim=1)

        output = self.base(inputs_embeds=inputs_embeds,
                           attention_mask=target_masks,
                           labels=labels)

        return output

    def generate(self,
                 encoded_patch: torch.Tensor,   # [hidden_size]
                 tokens: torch.Tensor):         # [1]
        """
        The generate function for generating a patch based on the encoded patch and already generated tokens.
        :param encoded_patch: the encoded patch
        :param tokens: already generated tokens in the patch; the first one is
                       replaced by the encoded patch when building the model input
        :return: the probability distribution of next token
        """
        encoded_patch = encoded_patch.reshape(1, 1, -1)   # [1, 1, hidden_size]
        tokens = tokens.reshape(1, -1)

        # Get input embeddings
        tokens = torch.nn.functional.embedding(tokens, self.base.transformer.wte.weight)

        # Concatenate the encoded patch with the input embeddings
        tokens = torch.cat((encoded_patch, tokens[:, 1:, :]), dim=1)

        # Get output from model
        outputs = self.base(inputs_embeds=tokens)

        # Get probabilities of next token from the logits at the last position
        probs = torch.nn.functional.softmax(outputs.logits.squeeze(0)[-1], dim=-1)

        return probs
332
+
333
class NotaGenLMHeadModel(PreTrainedModel):
    """
    NotaGen is a language model with a hierarchical structure.
    It includes a patch-level decoder and a char-level decoder.
    The patch-level decoder is used to generate patch features in an auto-regressive manner.
    The char-level decoder is used to generate the chars within each patch in an auto-regressive manner.
    It inherits PreTrainedModel from transformers.
    """
    def __init__(self, encoder_config, decoder_config):
        """
        Build the two decoders.

        :param encoder_config: the configuration of the patch-level decoder
                               (also passed to PreTrainedModel)
        :param decoder_config: the configuration of the char-level decoder
        """
        super().__init__(encoder_config)
        self.special_token_id = 0
        self.bos_token_id = 1
        self.eos_token_id = 2
        self.patch_level_decoder = PatchLevelDecoder(encoder_config)
        self.char_level_decoder = CharLevelDecoder(decoder_config)

    def forward(self,
                patches: torch.Tensor,
                masks: torch.Tensor):
        """
        The forward pass of the NotaGen model.
        :param patches: the patches to be encoded
        :param masks: the masks for the patches (left unmodified)
        :return: the decoded patches (the output of the char-level decoder)
        """
        patches = patches.reshape(len(patches), -1, PATCH_SIZE)
        encoded_patches = self.patch_level_decoder(patches, masks)["last_hidden_state"]

        # Reverse cumulative sum of the mask; keeping only positions where it
        # exceeds 1 drops the last position that still has mask weight
        # (presumably 0/1 masks, i.e. the last valid patch - TODO confirm).
        left_shift_masks = masks * (masks.flip(1).cumsum(1).flip(1) > 1)

        # Targets exclude the first patch of every sequence. Work on a copy so
        # the caller's mask tensor is not modified as a side effect.
        target_masks = masks.clone()
        target_masks[:, 0] = 0

        # Pair each kept encoded patch with the patch it has to predict.
        encoded_patches = encoded_patches[left_shift_masks == 1]
        patches = patches[target_masks == 1]

        return self.char_level_decoder(encoded_patches, patches)

    def generate(self,
                 patches: torch.Tensor,
                 top_k=0,
                 top_p=1,
                 temperature=1.0):
        """
        The generate function for generating patches based on patches.
        :param patches: the patches to be encoded; a trailing partial patch
                        (length not a multiple of PATCH_SIZE) is used as the
                        already generated tokens of the patch being completed
        :param top_k: the top k for sampling
        :param top_p: the top p for sampling
        :param temperature: the temperature for sampling
        :return: the generated patches (the list of newly sampled token ids)
        """
        bos = torch.tensor([self.bos_token_id], device=self.device)

        remainder = patches.shape[-1] % PATCH_SIZE
        if remainder != 0:
            # Split off the partial patch and continue it after a bos token.
            tokens = patches[:, :, -remainder:].squeeze(0, 1)
            tokens = torch.cat((bos, tokens), dim=-1)
            patches = patches[:, :, :-remainder]
        else:
            tokens = bos

        patches = patches.reshape(len(patches), -1, PATCH_SIZE)   # [bs, seq, patch_size]
        encoded_patches = self.patch_level_decoder(patches)["last_hidden_state"]   # [bs, seq, hidden_size]
        # Only the last patch feature of the first sequence conditions the new patch.
        last_encoded_patch = encoded_patches[0][-1]
        generated_patch = []

        while True:
            prob = self.char_level_decoder.generate(last_encoded_patch, tokens).cpu().detach().numpy()   # [128]
            prob = top_k_sampling(prob, top_k=top_k, return_probs=True)   # [128]
            prob = top_p_sampling(prob, top_p=top_p, return_probs=True)   # [128]
            token = temperature_sampling(prob, temperature=temperature)   # int
            generated_patch.append(token)

            # The patch is always filled up to PATCH_SIZE; an eos token does not stop the loop early.
            if len(tokens) >= PATCH_SIZE:
                break
            tokens = torch.cat((tokens, torch.tensor([token], device=self.device)), dim=0)

        return generated_patch