Commit · 87ae0b7
Parent(s): 2d7a385
Initial commit
Files changed:
- app.py +245 -0
- assets/asciilogo.txt +11 -0
- requirements.txt +11 -0
- source/languagemodel.py +288 -0
- source/utilities.py +331 -0
app.py
ADDED
@@ -0,0 +1,245 @@
import streamlit as st
from source.languagemodel import LanguageModel
from source.utilities import (
    convert_tokens_to_songdata,
    convert_songdata_to_notesequence,
    convert_songdata_to_pianoroll,
    convert_notesequence_to_wave,
    convert_notesequence_to_midi
)

# Define the MIDI instruments.
midi_instruments = {
    "Harpsichord": 6,
    "Church Organ": 19,
    "Piano": 0,
}

# Load the model once and cache it.
@st.cache_resource
def load_model():
    model = LanguageModel("TristanBehrens/bach-garland-mambaplus")
    return model
model = load_model()


# Initialize the session state if it doesn't exist yet.
if "token_sequence" not in st.session_state:
    st.session_state.token_sequence = "GARLAND_START"
    st.session_state.song_data = None
    st.session_state.piano_roll = None
    st.session_state.wave = None
    st.session_state.note_sequence = None
    st.session_state.midi_file_content = None
    st.session_state.temperature = 0.1
    st.session_state.bpm = 100
    st.session_state.instrument = "Piano"


# Define the main function.
def main():

    columns = st.columns([0.7, 0.3])

    # Set up the Streamlit application.
    column = columns.pop(0)
    with column:

        # Change the link colors to #FF4B4B.
        st.markdown("<style>a:link { color: #FF4B4B; } a:visited { color: #FF4B4B; }</style>", unsafe_allow_html=True)

        # Add a title.
        st.title("Garland Composer")
        linkedin_url = "https://huggingface.co/TristanBehrens/bach-garland-mambaplus/"
        x_url = "https://huggingface.co/TristanBehrens/bach-garland-mambaplus/"
        st.write(f"By Dr. Tristan Behrens. Find me on [LinkedIn]({linkedin_url}) and [X]({x_url}).")
        hf_url = "https://huggingface.co/TristanBehrens/bach-garland-mambaplus/"
        st.write(f"Model available on [Hugging Face]({hf_url}).")

    # Add a picture.
    column = columns.pop(0)
    with column:
        st.write(" ")
        st.write(" ")
        st.write(" ")
        st.image("garland.jpg", use_column_width=True)

    # Add a horizontal line.
    st.markdown("---")

    # Create three columns for the controls.
    columns = st.columns(3)

    # Add a slider to control the temperature.
    state_temperature = st.session_state.temperature
    with columns.pop(0):
        temperature = st.slider("Temperature", 0.0, 1.0, state_temperature)
        st.session_state.temperature = temperature

    # Add a slider to control the BPM.
    state_bpm = st.session_state.bpm
    with columns.pop(0):
        bpm = st.slider("BPM", 80, 120, state_bpm, 5)
        st.session_state.bpm = bpm

    # Dropdown for the instrument.
    state_instrument = st.session_state.instrument
    with columns.pop(0):
        instrument = st.selectbox("Instrument", list(midi_instruments.keys()), index=list(midi_instruments.keys()).index(state_instrument))
        st.session_state.instrument = instrument

    # Get the token sequence from the session state.
    token_sequence = st.session_state.token_sequence

    # Columns for the buttons.
    columns = st.columns(5)

    # Add a button to generate the next bar.
    column = columns.pop(0)
    with column:
        if st.button("Add a bar", use_container_width=True):
            token_sequence = extend_sequence(model, token_sequence, temperature)
            refresh(token_sequence, bpm, instrument)

    # Add a button to compose until the model signals the end.
    column = columns.pop(0)
    with column:
        if st.button("Auto compose", use_container_width=True):
            token_sequence = auto_compose(model, token_sequence, temperature)
            refresh(token_sequence, bpm, instrument)

    # Add a button to remove the last bar.
    column = columns.pop(0)
    with column:
        if st.button("Remove last", use_container_width=True):
            token_sequence = shortened_sequence(token_sequence)
            refresh(token_sequence, bpm, instrument)

    # Add a button to reset the sequence.
    column = columns.pop(0)
    if token_sequence != "GARLAND_START":
        with column:
            if st.button("Reset", use_container_width=True):
                token_sequence = "GARLAND_START"
                refresh(token_sequence, bpm, instrument)

    # Provide a download button for the MIDI file.
    column = columns.pop(0)
    if "midi_file_content" in st.session_state and st.session_state.midi_file_content is not None:
        with column:
            midi_file_content = st.session_state.midi_file_content
            st.download_button(
                label="Download MIDI",
                data=midi_file_content,
                file_name="music.mid",
                mime="audio/midi",
                use_container_width=True
            )

    # Add a horizontal line.
    st.markdown("---")

    # Display the piano roll.
    if "piano_roll" in st.session_state and st.session_state.piano_roll is not None:
        st.image(st.session_state.piano_roll)

    # Display an audio player.
    if "wave" in st.session_state and st.session_state.wave is not None:
        st.audio(st.session_state.wave, format="audio/wav", sample_rate=44100, autoplay=True)

    # Add a horizontal line.
    st.markdown("---")

    # Tell the user whether the model considers the piece finished.
    if token_sequence.endswith("GARLAND_END"):
        st.write("The AI believes that the music is finished.")
    else:
        st.write("The AI believes that the music is not finished.")


def auto_compose(model, token_sequence, temperature):

    max_iterations = 100
    for _ in range(max_iterations):
        token_sequence = extend_sequence(model, token_sequence, temperature)
        if token_sequence.endswith("GARLAND_END"):
            break
    return token_sequence


def extend_sequence(model, token_sequence, temperature):

    # Replace the last GARLAND_END token with NEXT.
    if token_sequence.endswith("GARLAND_END"):
        token_sequence = token_sequence.replace("GARLAND_END", "NEXT")

    # The maximum length of the generated music.
    max_length = 16_384

    # When to stop the generation.
    end_tokens = ["NEXT", "GARLAND_END"]

    # Compose the music iteratively, bar by bar.
    output_dict = model.generate(
        prompt=token_sequence,
        temperature=temperature,
        max_length=max_length,
        end_tokens=end_tokens,
        forbidden_tokens=["[PAD]", "[EOS]"],
        return_structured_output=True
    )
    output = output_dict["output"]
    return output


def shortened_sequence(token_sequence):

    # Find the position of the next-to-last NEXT (or GARLAND_END) token and cut there.
    tokens = token_sequence.split()
    next_positions = [i for i, x in enumerate(tokens) if x in ("NEXT", "GARLAND_END")]
    if len(next_positions) <= 1:
        token_sequence = "GARLAND_START"
    else:
        next_position = next_positions[-2]
        token_sequence = " ".join(tokens[:next_position + 1])
    return token_sequence


def refresh(token_sequence="GARLAND_START", bpm=120, instrument="Piano"):

    # Store the token sequence in the session state.
    st.session_state.token_sequence = token_sequence

    # Convert to song data.
    song_data = convert_tokens_to_songdata(token_sequence)
    song_data["bpm"] = bpm
    st.session_state.song_data = song_data

    # Set the instrument.
    for track in song_data["tracks"]:
        track["instrument"] = midi_instruments[instrument]

    # Convert to piano roll.
    piano_roll = convert_songdata_to_pianoroll(song_data)
    st.session_state.piano_roll = piano_roll

    # Convert to note sequence.
    note_sequence = convert_songdata_to_notesequence(song_data)
    st.session_state.note_sequence = note_sequence

    # Render the note sequence to audio.
    wave = convert_notesequence_to_wave(note_sequence)
    st.session_state.wave = wave

    # Get the MIDI file content.
    midi_file_content = convert_notesequence_to_midi(note_sequence)
    st.session_state.midi_file_content = midi_file_content

    # Rerun the app.
    st.rerun()


if __name__ == "__main__":
    main()
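
Since app.py is a standard Streamlit entry point, the Space can also be run locally. A minimal sketch, assuming the pinned dependencies from requirements.txt below are installed and the model repo is reachable on the Hugging Face Hub:

    streamlit run app.py

On first launch, @st.cache_resource makes load_model() download and build the model once; subsequent reruns triggered by refresh()'s st.rerun() reuse the cached instance.
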
assets/asciilogo.txt
ADDED
@@ -0,0 +1,11 @@
 ▄█    █▄     ▄████████  ▄█        ▄█  ▀█████████▄     ▄████████ ███    █▄  ███▄▄▄▄   ███▄▄▄▄      ▄████████
███    ███   ███    ███ ███       ███  ███    ███    ███    ███ ███    ███ ███▀▀▀██▄ ███▀▀▀██▄   ███    ███
███    ███   ███    █▀  ███       ███▌ ███    ███    ███    ███ ███    ███ ███   ███ ███   ███   ███    ███
▄███▄▄▄▄███▄▄ ▄███▄▄▄    ███       ███▌ ▄███▄▄▄██▀    ▄███▄▄▄▄██▀ ███    ███ ███   ███ ███   ███   ███    ███
▀▀███▀▀▀▀███▀ ▀▀███▀▀▀    ███       ███▌ ▀▀███▀▀▀██▄  ▀▀███▀▀▀▀▀   ███    ███ ███   ███ ███   ███ ▀███████████
███    ███   ███    █▄  ███       ███  ███    ██▄  ▀███████████ ███    ███ ███   ███ ███   ███   ███    ███
███    ███   ███    ███ ███▌    ▄ ███  ███    ███    ███    ███ ███    ███ ███   ███ ███   ███   ███    ███
███    █▀    ██████████ █████▄▄██ █▀   ▄█████████▀     ███    ███ ████████▀   ▀█   █▀   ▀█   █▀    ███    █▀
                        ▀                              ███    ███

By Dr. Tristan Behrens
requirements.txt
ADDED
@@ -0,0 +1,11 @@
dacite==1.8.1
colorama==0.4.6
omegaconf==2.3.0
streamlit==1.38.0
note_seq==0.0.5
pyfluidsynth==1.3.2
torch==2.2.0
transformers==4.44.0
mamba-ssm==2.2.2
einops==0.8.0
mambapy==1.2.0
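
A typical local setup would install these pinned dependencies into a fresh virtual environment (the exact commands are an assumption, not part of this commit):

    python -m venv .venv
    source .venv/bin/activate
    pip install -r requirements.txt

Note that pyfluidsynth is only a binding; audio rendering additionally needs the native FluidSynth library on the system, otherwise convert_notesequence_to_wave in source/utilities.py falls back to note_seq.synthesize.
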
source/languagemodel.py
ADDED
@@ -0,0 +1,288 @@
# Helibrunna - A HuggingFace compatible xLSTM trainer.
# Copyright (c) 2024 Dr. Tristan Behrens
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

import os
import glob
import json
import time
from omegaconf import OmegaConf
from transformers import PreTrainedTokenizerFast
import torch
from safetensors.torch import load_file
from .utilities import display_logo, model_from_config


class LanguageModel:

    def __init__(self, model_path_or_repo, config_overrides={}, mask_special_tokens=True, device="auto"):
        """
        Initializes the LanguageModel object.
        Args:
            model_path_or_repo (str): The path to the model or the repository ID.
        Raises:
            ValueError: If the model checkpoint, tokenizer, config, or weights are not found.
            Exception: If the model could not be downloaded.
        Returns:
            None
        """

        # Set the mask_special_tokens flag.
        self.mask_special_tokens = mask_special_tokens

        # Use the requested device if one was given explicitly.
        if device != "auto":

            # Check if CUDA is available.
            if not torch.cuda.is_available() and device == "cuda":
                raise ValueError("CUDA is not available on this system.")

            # Check if MPS is available.
            if not torch.backends.mps.is_available() and device == "mps":
                raise ValueError("MPS is not available on this system.")

            # Set the device.
            self.device = device

        # Otherwise resolve the device automatically.
        else:

            # Default to CPU.
            self.device = "cpu"

            # Use CUDA if it is available.
            if torch.cuda.is_available():
                self.device = "cuda"

            # See if MPS is available.
            # Note: This is disabled for now. It is not working as expected and is very slow.
            #if torch.backends.mps.is_available():
            #    self.device = "mps"

        # Display the logo.
        display_logo()

        # Download the model if it doesn't exist locally. Or at least try to.
        if not os.path.exists(model_path_or_repo):
            from huggingface_hub import snapshot_download
            try:
                model_path = snapshot_download(repo_id=model_path_or_repo)
                tokenizer_path = model_path
            except Exception as e:
                raise ValueError(f"Failed to download the model: {e}")

        # Use a local model.
        else:
            # Set the model path and tokenizer path.
            model_path = None
            tokenizer_path = model_path_or_repo

            # Find all the checkpoint folders (folders that start with "checkpoint-") and pick the last one.
            checkpoint_folders = glob.glob(os.path.join(model_path_or_repo, "checkpoint-*"))
            for checkpoint_folder in checkpoint_folders:
                if checkpoint_folder.endswith("-last"):
                    model_path = checkpoint_folder
                    break
            if model_path is None:
                raise ValueError("No model checkpoint found.")

            # Find the tokenizer folder.
            if os.path.exists(os.path.join(model_path_or_repo, "tokenizer.json")):
                tokenizer_path = model_path_or_repo
            if not os.path.exists(tokenizer_path):
                raise ValueError("Tokenizer not found.")

        # Load the config.
        config_path = os.path.join(model_path, "config.yaml")
        if not os.path.exists(config_path):
            raise ValueError(f"Config not found at {config_path}")
        model_config = OmegaConf.load(config_path)

        # Override the config.
        if config_overrides:
            model_config = OmegaConf.merge(model_config, config_overrides)
            print(json.dumps(OmegaConf.to_container(model_config), indent=4))

        # Create the model from the config.
        model = model_from_config(model_config, device=self.device)
        model.to(self.device)
        self.config = model_config

        # Load the weights from the checkpoint.
        weights_path = os.path.join(model_path, "model.safetensors")
        if not os.path.exists(weights_path):
            raise ValueError(f"Weights not found at {weights_path}")
        state_dict = load_file(weights_path)

        # The vanilla sLSTM backend expects the recurrent kernels transposed, so permute the
        # last two dimensions of those parameters when running an xLSTM model without CUDA.
        if not torch.cuda.is_available() and model_config.get("type", "xLSTMLMModel") == "xLSTMLMModel":
            endings = ["xlstm.slstm_cell._recurrent_kernel_"]
            for key, values in state_dict.items():
                for ending in endings:
                    if key.endswith(ending):
                        state_dict[key] = values.permute(0, 2, 1)
                        break

        # Load the weights into the model.
        model.load_state_dict(state_dict)
        self.model = model

        # Load the tokenizer.
        tokenizer_path = os.path.join(tokenizer_path, "tokenizer.json")
        if not os.path.exists(tokenizer_path):
            raise ValueError(f"Tokenizer not found at {tokenizer_path}")
        tokenizer = PreTrainedTokenizerFast.from_pretrained(tokenizer_path)
        self.tokenizer = tokenizer


    def generate(
        self,
        prompt: str,
        temperature: float = 1.0,
        max_length: int = 100,
        end_tokens: list[str] = [],
        forbidden_tokens: list[str] = [],
        return_structured_output: bool = False
    ):
        """
        Generates a continuation for a given prompt using the language model.
        Args:
            prompt (str): The prompt to generate a continuation for.
            temperature (float, optional): The temperature value for controlling the randomness of the generated output.
                Higher values (e.g., 1.0) make the output more random, while lower values (e.g., 0.5) make it more deterministic.
                Defaults to 1.0.
            max_length (int, optional): The maximum length of the generated output. Defaults to 100.
            end_tokens (list[str], optional): A list of end tokens that, if encountered, will stop the generation process.
                Defaults to an empty list.
            forbidden_tokens (list[str], optional): A list of tokens that must never be sampled. Defaults to an empty list.
            return_structured_output (bool, optional): If True, returns a dictionary with the generated output, elapsed time,
                and tokens per second. If False, returns only the generated output as a string. Defaults to False.
        Returns:
            str or dict: The generated output as a string if return_structured_output is False.
                A dictionary with the generated output, elapsed time, and tokens per second if return_structured_output is True.
        """

        # Tokenize the prompt.
        inputs = self.tokenizer(prompt, return_tensors="pt").input_ids.to(self.device)
        assert inputs.shape[0] == 1

        # Determine the end token ids.
        end_token_ids = []
        for end_token in end_tokens:
            assert end_token in self.tokenizer.vocab
            end_token_ids.append(self.tokenizer(end_token).input_ids[0])

        # Initialize the ids to mask.
        ids_to_mask = []

        # Collect the ids of the forbidden tokens.
        for forbidden_token in forbidden_tokens:
            assert forbidden_token in self.tokenizer.vocab
            ids_to_mask.extend(self.tokenizer(forbidden_token).input_ids)

        # Generate the continuation.
        start_time = time.time()
        tokens_count = 0
        while inputs.shape[1] < max_length:

            # Stop if the maximum context length is reached.
            if inputs.shape[1] >= self.config.context_length:
                print("Warning: The maximum context length has been reached.")
                break

            # Run the model to get the next-token logits.
            outputs = self.model(inputs.to(device=self.device))
            assert outputs.shape[0] == 1

            # Mask the special tokens and the forbidden tokens.
            if self.mask_special_tokens:
                outputs[:, :, self.tokenizer.all_special_ids] = float("-inf")
            if ids_to_mask:
                outputs[:, :, ids_to_mask] = float("-inf")

            # Use the temperature to sample from the distribution (guarding against a temperature of zero).
            outputs = outputs / max(temperature, 1e-6)
            outputs = torch.nn.functional.softmax(outputs, dim=-1)
            outputs = torch.multinomial(outputs[0, -1], num_samples=1)

            # Append the sampled token to the inputs.
            inputs = torch.cat([inputs, outputs.unsqueeze(0)], dim=1)

            # Increment the tokens count.
            tokens_count += 1

            # Check if an end token is reached.
            if outputs[0] in end_token_ids:
                break

        # Compute the elapsed time and tokens per second.
        elapsed_time = time.time() - start_time
        tokens_per_second = tokens_count / elapsed_time

        # Decode the output.
        output = self.tokenizer.decode(inputs[0].tolist())

        # Return the output.
        if not return_structured_output:
            return output

        # Return the structured output.
        else:
            return {
                "output": output,
                "elapsed_time": elapsed_time,
                "tokens_per_second": tokens_per_second
            }

    def summary(self):
        """
        Prints a summary of the model. Makes the model architecture readable. Includes the number of parameters.
        """

        # Print the model.
        print(self.model)

        # Get the number of parameters.
        number_of_parameters = sum(p.numel() for p in self.model.parameters())
        print(f"Number of parameters: {number_of_parameters:_}")

        # Make the parameter count human readable.
        sizes = ["", "K", "M", "B", "T"]
        size_index = 0
        scaled_parameters = number_of_parameters
        while scaled_parameters > 1000:
            scaled_parameters /= 1000
            size_index += 1
        print(f"Number of parameters: {scaled_parameters:.2f}{sizes[size_index]}")

        # Estimate the size of the model in bytes and make it human readable.
        total_size = number_of_parameters * 4
        sizes = ["B", "KB", "MB", "GB", "TB"]
        size_index = 0
        while total_size > 1024:
            total_size /= 1024
            size_index += 1
        print(f"Total size of the model: {total_size:.2f}{sizes[size_index]} for precision 32-bit floats.")

        # Print on which device the model is running.
        print(f"Device: {self.device}")
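
For illustration, a minimal sketch of driving LanguageModel directly, outside the Streamlit app. It uses only the interface defined above; the repo ID is the one hard-coded in app.py, and the token strings follow the Garland format parsed in source/utilities.py:

    from source.languagemodel import LanguageModel

    # Loads from a local checkpoint folder or downloads the repo from the Hub.
    model = LanguageModel("TristanBehrens/bach-garland-mambaplus")
    model.summary()

    # Generate one bar: stop at NEXT or GARLAND_END, never sample [PAD] or [EOS].
    result = model.generate(
        prompt="GARLAND_START",
        temperature=0.5,
        max_length=16_384,
        end_tokens=["NEXT", "GARLAND_END"],
        forbidden_tokens=["[PAD]", "[EOS]"],
        return_structured_output=True,
    )
    print(result["output"])
    print(f"{result['tokens_per_second']:.1f} tokens/s")
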
source/utilities.py
ADDED
@@ -0,0 +1,331 @@
import copy
import colorsys
import os
import tempfile
import colorama
import note_seq
import torch
from PIL import Image
from omegaconf import DictConfig, OmegaConf
from dacite import from_dict


# NOTE: Imported from helibrunna.
def display_logo():
    """
    Display the logo by printing it line by line with a cyberpunk color scheme.

    Raises:
        FileNotFoundError: If the logo file is missing.
    """

    # Get the path of this script and use it to find the logo.
    script_path = os.path.dirname(os.path.realpath(__file__))
    search_path = os.path.dirname(script_path)

    # Load the logo.
    logo_path = os.path.join(search_path, "assets", "asciilogo.txt")
    if not os.path.exists(logo_path):
        raise FileNotFoundError("The logo file is missing.")
    with open(logo_path, "r") as f:
        logo = f.read()

    # Print the logo line by line. Use colorama to colorize the output with a cyberpunk color scheme.
    for line_index, line in enumerate(logo.split("\n")):
        color = colorama.Fore.GREEN
        style = colorama.Style.BRIGHT if line_index % 2 == 0 else colorama.Style.NORMAL
        print(color + style + line)
    print(colorama.Style.RESET_ALL)


# NOTE: Imported from helibrunna.
def model_from_config(model_config: DictConfig, device: str) -> torch.nn.Module:
    """
    Create a model based on the provided model configuration.

    Args:
        model_config (DictConfig): The configuration for the model.
        device (str): The device to move the model to.

    Returns:
        The created model.

    Raises:
        ValueError: If the model type is unknown.
    """

    # Get the model type from the configuration.
    model_type = model_config.get("type", "xLSTMLMModel")

    # Create the xLSTMLMModel.
    if model_type == "xLSTMLMModel":
        print("Creating xLSTMLMModel...")
        from xlstm.xlstm_lm_model import xLSTMLMModel, xLSTMLMModelConfig

        # If there is no GPU, use the vanilla backend.
        if not torch.cuda.is_available():
            model_config.slstm_block.slstm.backend = "vanilla"
            model_config.mlstm_block.mlstm.backend = "vanilla"
        model_config_object = from_dict(xLSTMLMModelConfig, OmegaConf.to_container(model_config))

        # Create the model.
        model = xLSTMLMModel(model_config_object)
        model.reset_parameters()

    # Create the GPT2LMModel.
    elif model_type == "gpt2":
        print("Creating GPT2LMModel...")
        from .models.gpttwo import GPT2LMModel, GPT2LMModelConfig
        model_config_object = from_dict(GPT2LMModelConfig, OmegaConf.to_container(model_config))
        model = GPT2LMModel(model_config_object)

    # Create the Mamba LM.
    elif model_type == "mamba":
        print("Creating Mamba LM...")
        from mambapy.lm import LM, MambaConfig
        model_config_object = from_dict(MambaConfig, OmegaConf.to_container(model_config))
        model = LM(model_config_object, model_config.vocab_size)

    # Create the Transformer.
    elif model_type == "transformer":
        from .models.transformer import TransformerConfig, Transformer
        model_config_object = from_dict(TransformerConfig, OmegaConf.to_container(model_config))
        model = Transformer(model_config_object)

    # Create a Pharia instance.
    elif model_type == "pharia":
        from .models.pharia import PhariaConfig, PhariaModel
        model_config_object = from_dict(PhariaConfig, OmegaConf.to_container(model_config))
        model = PhariaModel(model_config_object)

    # Fail on unknown model types.
    else:
        raise ValueError(f"Unknown model type: {model_type}")

    # Move the model to the device.
    model.to(device)
    return model


def convert_tokens_to_songdata(tokens):

    if isinstance(tokens, str):
        tokens = tokens.split()

    song_data = {}
    song_data["tracks"] = []

    current_track_index = 0
    current_timestep = 0
    for token in tokens:
        if token == "GARLAND_START":
            pass
        elif token == "BAR_START":
            # Create the track on demand, then open a new bar in it.
            if current_track_index == len(song_data["tracks"]):
                song_data["tracks"] += [{"bars": [], "instrument": "0"}]
            bar_data = {"notes": []}
            song_data["tracks"][current_track_index]["bars"] += [bar_data]
            current_timestep = 0
        elif token.startswith("INST="):
            instrument = token.split("=")[1]
            song_data["tracks"][current_track_index]["instrument"] = instrument
        elif token.startswith("DENSITY="):
            pass
        elif token.startswith("NOTE_ON="):
            # Open a note at the current timestep. The end is filled in by the matching NOTE_OFF.
            note_pitch = int(token.split("=")[1])
            note_data = {
                "note": note_pitch,
                "start": current_timestep,
                "end": current_timestep,
                "velocity": 80
            }
            song_data["tracks"][current_track_index]["bars"][-1]["notes"] += [note_data]
        elif token.startswith("TIME_DELTA="):
            current_timestep += int(token.split("=")[1])
        elif token.startswith("NOTE_OFF="):
            # Close the first still-open note with the matching pitch.
            note_pitch = int(token.split("=")[1])
            for note_data in song_data["tracks"][current_track_index]["bars"][-1]["notes"]:
                if note_data["note"] == note_pitch and note_data["start"] == note_data["end"]:
                    note_data["end"] = current_timestep
                    break
        elif token == "BAR_END":
            current_track_index += 1
        elif token == "NEXT":
            current_track_index = 0
        elif token == "GARLAND_END":
            pass
        elif token == "[PAD]":
            pass
        elif token == "[EOS]":
            pass
        else:
            raise Exception(f"Unknown token: {token}")

    assert isinstance(song_data, dict)
    return song_data


def convert_songdata_to_notesequence(song_data: dict, quantize_steps_per_quarter=8, remove_disabled_tracks=True):

    assert isinstance(song_data, dict), f"Invalid song data type: {type(song_data)}"

    # Clone the song data.
    song_data = copy.deepcopy(song_data)

    # Sort the tracks by instrument.
    assert "tracks" in song_data, f"Invalid song data: {song_data.keys()}"
    tracks = sorted(song_data["tracks"], key=lambda t: t["instrument"])
    song_data["tracks"] = tracks

    # Remove tracks that are not enabled.
    if remove_disabled_tracks:
        song_data["tracks"] = [t for t in song_data["tracks"] if t.get("enabled", True)]

    # Create an empty note sequence.
    note_sequence = note_seq.protobuf.music_pb2.NoteSequence()

    # Add the tempo.
    bpm = song_data["bpm"] if "bpm" in song_data else 120
    note_sequence.tempos.add().qpm = bpm

    # Compute some lengths. Bars are assumed to be 4/4.
    step_length_seconds = 60.0 / bpm / quantize_steps_per_quarter
    bar_length_seconds = 4 * step_length_seconds * quantize_steps_per_quarter

    # Add the tracks.
    for track_index, track_data in enumerate(song_data["tracks"]):
        instrument = track_data["instrument"]
        for bar_index, bar_data in enumerate(track_data["bars"]):
            bar_start_time = bar_index * bar_length_seconds
            for note_data in bar_data["notes"]:
                assert "note" in note_data
                assert "start" in note_data
                assert "end" in note_data
                note = note_sequence.notes.add()
                note.pitch = note_data["note"]
                note.start_time = note_data["start"] * step_length_seconds + bar_start_time
                note.end_time = note_data["end"] * step_length_seconds + bar_start_time
                note.velocity = note_data.get("velocity", 80)
                note.instrument = track_index
                if instrument == "drums":
                    note.is_drum = True
                else:
                    note.is_drum = False
                    note.program = int(instrument)

    return note_sequence


def convert_songdata_to_pianoroll(song_data):

    # The bars are 4/4 and the quantization is 8 steps per quarter, i.e. 32 steps per bar.
    # We render a grid with one pixel per step horizontally and one pixel per pitch vertically.

    # All tracks must have the same number of bars.
    lengths = [len(track["bars"]) for track in song_data["tracks"]]
    if lengths == []:
        return None
    assert len(set(lengths)) == 1, f"Unequal number of bars: {lengths}"
    num_bars = lengths[0]

    # Get the note extremes.
    min_note = 128
    max_note = 0
    for track_data in song_data["tracks"]:
        for bar_data in track_data["bars"]:
            for note_data in bar_data["notes"]:
                min_note = min(min_note, note_data["note"])
                max_note = max(max_note, note_data["note"])

    # The width depends on the number of bars.
    width = 32 * num_bars

    # The height depends on the note range.
    height = 1 + max_note - min_note

    # Create the image.
    image = Image.new("RGB", (width, height), (14, 17, 23))

    # Define one color per track by rotating the hue of the base color.
    base_color = (255, 75, 75)
    adjustments = [1.2, 1.0, 0.8, 0.6]
    colors = []
    for adjustment in adjustments:
        rgb = [float(c) / 255.0 for c in base_color]
        hsv = colorsys.rgb_to_hsv(*rgb)
        offset = (adjustment - 1.0) * 0.1
        hsv = (hsv[0] + offset, hsv[1], hsv[2])
        rgb = colorsys.hsv_to_rgb(*hsv)
        colors += [tuple(int(255.0 * c) for c in rgb)]

    # Draw the notes.
    for track_index, track_data in enumerate(song_data["tracks"]):
        color = colors[track_index % len(colors)]
        for bar_index, bar_data in enumerate(track_data["bars"]):
            x = bar_index * 32
            for note_data in bar_data["notes"]:
                y = max_note - note_data["note"]
                assert 0 <= y < height, f"Invalid y: {y}, note {note_data['note']}, min_note: {min_note}, max_note: {max_note}, height: {height}"
                for i in range(note_data["start"], note_data["end"]):
                    image.putpixel((x + i, y), color)

    # Resize the image. Use nearest neighbor for pixel art.
    factor = 4
    image = image.resize((width * factor, height * factor), Image.NEAREST)

    return image


def convert_notesequence_to_wave(note_sequence):

    if len(note_sequence.notes) == 0:
        return None

    # Prefer fluidsynth; fall back to the simple synthesizer if it is unavailable.
    try:
        wave = note_seq.fluidsynth(note_sequence, sample_rate=44100)
        return wave
    except Exception:
        wave = note_seq.synthesize(note_sequence)
        return wave


def convert_notesequence_to_midi(note_sequence):

    if len(note_sequence.notes) == 0:
        return None

    # Write the sequence to a temporary MIDI file and return the file content.
    with tempfile.NamedTemporaryFile(delete=False) as temp_file:
        filename = temp_file.name
    note_seq.sequence_proto_to_midi_file(note_sequence, filename)
    with open(filename, "rb") as file:
        content = file.read()
    os.remove(filename)
    return content
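
To make the token grammar handled by convert_tokens_to_songdata concrete, here is a small hand-written example (the pitch and timing values are illustrative): one track playing a single quarter-note C4, at the default 8 quantization steps per quarter:

    from source.utilities import (
        convert_tokens_to_songdata,
        convert_songdata_to_notesequence,
        convert_songdata_to_pianoroll,
    )

    tokens = (
        "GARLAND_START "
        "BAR_START INST=0 NOTE_ON=60 TIME_DELTA=8 NOTE_OFF=60 BAR_END "
        "GARLAND_END"
    )

    song_data = convert_tokens_to_songdata(tokens)
    song_data["bpm"] = 100

    note_sequence = convert_songdata_to_notesequence(song_data)  # 0.6 s quarter note at 100 BPM
    image = convert_songdata_to_pianoroll(song_data)             # 32x1-pixel roll, upscaled 4x

    print(song_data["tracks"][0]["bars"][0]["notes"])  # [{'note': 60, 'start': 0, 'end': 8, 'velocity': 80}]
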