Upload app.py
Browse files
app.py
ADDED
@@ -0,0 +1,245 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import math
|
3 |
+
import random
|
4 |
+
import os
|
5 |
+
import streamlit as st
|
6 |
+
import lyricsgenius
|
7 |
+
import transformers
|
8 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
9 |
+
|
10 |
+
|
11 |
+
|
12 |
+
st.set_page_config(page_title="HuggingArtists")
|
13 |
+
|
14 |
+
|
15 |
+
st.title("HuggingArtists")
|
16 |
+
st.sidebar.markdown(
|
17 |
+
"""
|
18 |
+
<style>
|
19 |
+
.aligncenter {
|
20 |
+
text-align: center;
|
21 |
+
}
|
22 |
+
</style>
|
23 |
+
<p class="aligncenter">
|
24 |
+
<img src="https://raw.githubusercontent.com/AlekseyKorshuk/huggingartists/master/img/logo.jpg" width="420" />
|
25 |
+
</p>
|
26 |
+
""",
|
27 |
+
unsafe_allow_html=True,
|
28 |
+
)
|
29 |
+
st.sidebar.markdown(
|
30 |
+
"""
|
31 |
+
<style>
|
32 |
+
.aligncenter {
|
33 |
+
text-align: center;
|
34 |
+
}
|
35 |
+
</style>
|
36 |
+
|
37 |
+
<p style='text-align: center'>
|
38 |
+
<a href="https://github.com/AlekseyKorshuk/huggingartists" target="_blank">GitHub</a> | <a href="https://wandb.ai/huggingartists/huggingartists/reportlist" target="_blank">Project Report</a>
|
39 |
+
</p>
|
40 |
+
|
41 |
+
<p class="aligncenter">
|
42 |
+
<a href="https://github.com/AlekseyKorshuk/huggingartists" target="_blank">
|
43 |
+
<img src="https://img.shields.io/github/stars/AlekseyKorshuk/huggingartists?style=social"/>
|
44 |
+
</a>
|
45 |
+
</p>
|
46 |
+
<p class="aligncenter">
|
47 |
+
<a href="https://t.me/joinchat/_CQ04KjcJ-4yZTky" target="_blank">
|
48 |
+
<img src="https://img.shields.io/badge/dynamic/json?color=blue&label=Telegram%20Channel&query=%24.result&url=https%3A%2F%2Fapi.telegram.org%2Fbot1929545866%3AAAFGhV-KKnegEcLiyYJxsc4zV6C-bdPEBtQ%2FgetChatMemberCount%3Fchat_id%3D-1001253621662&style=social&logo=telegram"/>
|
49 |
+
</a>
|
50 |
+
</p>
|
51 |
+
<p class="aligncenter">
|
52 |
+
<a href="https://colab.research.google.com/github/AlekseyKorshuk/huggingartists/blob/master/huggingartists-demo.ipynb" target="_blank">
|
53 |
+
<img src="https://colab.research.google.com/assets/colab-badge.svg"/>
|
54 |
+
</a>
|
55 |
+
</p>
|
56 |
+
""",
|
57 |
+
unsafe_allow_html=True,
|
58 |
+
)
|
59 |
+
|
60 |
+
|
61 |
+
|
62 |
+
st.sidebar.header("Generation settings:")
|
63 |
+
num_sequences = st.sidebar.number_input(
|
64 |
+
"Number of sequences to generate",
|
65 |
+
min_value=1,
|
66 |
+
value=5,
|
67 |
+
help="The amount of generated texts",
|
68 |
+
)
|
69 |
+
min_length = st.sidebar.number_input(
|
70 |
+
"Minimum length of the sequence",
|
71 |
+
min_value=1,
|
72 |
+
value=100,
|
73 |
+
help="The minimum length of the sequence to be generated",
|
74 |
+
)
|
75 |
+
max_length= st.sidebar.number_input(
|
76 |
+
"Maximum length of the sequence",
|
77 |
+
min_value=1,
|
78 |
+
value=160,
|
79 |
+
help="The maximum length of the sequence to be generated",
|
80 |
+
)
|
81 |
+
temperature = st.sidebar.slider(
|
82 |
+
"Temperature",
|
83 |
+
min_value=0.0,
|
84 |
+
max_value=3.0,
|
85 |
+
step=0.01,
|
86 |
+
value=1.0,
|
87 |
+
help="The value used to module the next token probabilities",
|
88 |
+
)
|
89 |
+
top_p = st.sidebar.slider(
|
90 |
+
"Top-P",
|
91 |
+
min_value=0.0,
|
92 |
+
max_value=1.0,
|
93 |
+
step=0.01,
|
94 |
+
value=0.95,
|
95 |
+
help="If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation.",
|
96 |
+
)
|
97 |
+
|
98 |
+
top_k= st.sidebar.number_input(
|
99 |
+
"Top-K",
|
100 |
+
min_value=0,
|
101 |
+
value=50,
|
102 |
+
step=1,
|
103 |
+
help="The number of highest probability vocabulary tokens to keep for top-k-filtering.",
|
104 |
+
)
|
105 |
+
|
106 |
+
caption = (
|
107 |
+
"In [HuggingArtists](https://github.com/AlekseyKorshuk/huggingartist), we can generate lyrics by a specific artist. This was made by fine-tuning a pre-trained [HuggingFace Transformer](https://huggingface.co) on parsed datasets from [Genius](https://genius.com)."
|
108 |
+
)
|
109 |
+
st.markdown("[HuggingArtists](https://github.com/AlekseyKorshuk/huggingartist) - Train a model to generate lyrics π΅")
|
110 |
+
st.markdown(caption)
|
111 |
+
|
112 |
+
st.subheader("Settings:")
|
113 |
+
artist_name = st.text_input("Artist name:", "Eminem")
|
114 |
+
start = st.text_input("Beginning of the song:", "But for me to rap like a computer")
|
115 |
+
|
116 |
+
TOKEN = "q_JK_BFy9OMiG7fGTzL-nUto9JDv3iXI24aYRrQnkOvjSCSbY4BuFIindweRsr5I"
|
117 |
+
genius = lyricsgenius.Genius(TOKEN)
|
118 |
+
|
119 |
+
model_html = """
|
120 |
+
|
121 |
+
<div class="inline-flex flex-col" style="line-height: 1.5;">
|
122 |
+
<div class="flex">
|
123 |
+
<div
|
124 |
+
\t\t\tstyle="display:DISPLAY_1; margin-left: auto; margin-right: auto; width: 92px; height:92px; border-radius: 50%; background-size: cover; background-image: url('USER_PROFILE')">
|
125 |
+
</div>
|
126 |
+
</div>
|
127 |
+
<div style="text-align: center; margin-top: 3px; font-size: 16px; font-weight: 800">π€ HuggingArtists Model π€</div>
|
128 |
+
<div style="text-align: center; font-size: 16px; font-weight: 800">USER_NAME</div>
|
129 |
+
<a href="https://genius.com/artists/USER_HANDLE">
|
130 |
+
\t<div style="text-align: center; font-size: 14px;">@USER_HANDLE</div>
|
131 |
+
</a>
|
132 |
+
</div>
|
133 |
+
"""
|
134 |
+
|
135 |
+
|
136 |
+
def post_process(output_sequences):
|
137 |
+
predictions = []
|
138 |
+
generated_sequences = []
|
139 |
+
|
140 |
+
max_repeat = 2
|
141 |
+
|
142 |
+
# decode prediction
|
143 |
+
for generated_sequence_idx, generated_sequence in enumerate(output_sequences):
|
144 |
+
generated_sequence = generated_sequence.tolist()
|
145 |
+
text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True, skip_special_tokens=True)
|
146 |
+
generated_sequences.append(text.strip())
|
147 |
+
|
148 |
+
for i, g in enumerate(generated_sequences):
|
149 |
+
res = str(g).replace('\n\n\n', '\n').replace('\n\n', '\n')
|
150 |
+
lines = res.split('\n')
|
151 |
+
# print(lines)
|
152 |
+
# i = max_repeat
|
153 |
+
# while i != len(lines):
|
154 |
+
# remove_count = 0
|
155 |
+
# for index in range(0, max_repeat):
|
156 |
+
# # print(i - index - 1, i - index)
|
157 |
+
# if lines[i - index - 1] == lines[i - index]:
|
158 |
+
# remove_count += 1
|
159 |
+
# if remove_count == max_repeat:
|
160 |
+
# lines.pop(i)
|
161 |
+
# i -= 1
|
162 |
+
# else:
|
163 |
+
# i += 1
|
164 |
+
predictions.append('\n'.join(lines))
|
165 |
+
|
166 |
+
return predictions
|
167 |
+
|
168 |
+
if st.button("Run"):
|
169 |
+
model_name = None
|
170 |
+
with st.spinner(text=f"Searching for {artist_name } in Genius..."):
|
171 |
+
artist = genius.search_artist(artist_name, max_songs=0, get_full_info=False)
|
172 |
+
if artist is not None:
|
173 |
+
artist_dict = genius.artist(artist.id)['artist']
|
174 |
+
artist_url = str(artist_dict['url'])
|
175 |
+
model_name = artist_url[artist_url.rfind('/') + 1:].lower()
|
176 |
+
st.markdown(model_html.replace("USER_PROFILE",artist.image_url).replace("USER_NAME",artist.name).replace("USER_HANDLE",model_name), unsafe_allow_html=True)
|
177 |
+
else:
|
178 |
+
st.markdown(f"Could not find {artist_name}! Be sure that he/she exists in [Genius](https://genius.com/).")
|
179 |
+
if model_name is not None:
|
180 |
+
with st.spinner(text=f"Downloading the model of {artist_name }..."):
|
181 |
+
model = None
|
182 |
+
tokenizer = None
|
183 |
+
try:
|
184 |
+
tokenizer = AutoTokenizer.from_pretrained(f"huggingartists/{model_name}")
|
185 |
+
model = AutoModelForCausalLM.from_pretrained(f"huggingartists/{model_name}")
|
186 |
+
except Exception as ex:
|
187 |
+
# st.markdown(ex)
|
188 |
+
st.markdown(f"Model for this artist does not exist yet. Create it in just 5 min with [Colab Notebook](https://colab.research.google.com/github/AlekseyKorshuk/huggingartists/blob/master/huggingartists-demo.ipynb):")
|
189 |
+
st.markdown(
|
190 |
+
"""
|
191 |
+
<style>
|
192 |
+
.aligncenter {
|
193 |
+
text-align: center;
|
194 |
+
}
|
195 |
+
</style>
|
196 |
+
<p class="aligncenter">
|
197 |
+
<a href="https://colab.research.google.com/github/AlekseyKorshuk/huggingartists/blob/master/huggingartists-demo.ipynb" target="_blank">
|
198 |
+
<img src="https://colab.research.google.com/assets/colab-badge.svg"/>
|
199 |
+
</a>
|
200 |
+
</p>
|
201 |
+
""",
|
202 |
+
unsafe_allow_html=True,
|
203 |
+
)
|
204 |
+
if model is not None:
|
205 |
+
with st.spinner(text=f"Generating lyrics..."):
|
206 |
+
encoded_prompt = tokenizer(start, add_special_tokens=False, return_tensors="pt").input_ids
|
207 |
+
encoded_prompt = encoded_prompt.to(model.device)
|
208 |
+
# prediction
|
209 |
+
output_sequences = model.generate(
|
210 |
+
input_ids=encoded_prompt,
|
211 |
+
max_length=max_length,
|
212 |
+
min_length=min_length,
|
213 |
+
temperature=float(temperature),
|
214 |
+
top_p=float(top_p),
|
215 |
+
top_k=int(top_k),
|
216 |
+
do_sample=True,
|
217 |
+
repetition_penalty=1.0,
|
218 |
+
num_return_sequences=num_sequences
|
219 |
+
)
|
220 |
+
# Post-processing
|
221 |
+
predictions = post_process(output_sequences)
|
222 |
+
st.subheader("Results")
|
223 |
+
for prediction in predictions:
|
224 |
+
st.text(prediction)
|
225 |
+
st.subheader("Please star this repository and join my Telegram Channel:")
|
226 |
+
st.markdown(
|
227 |
+
"""
|
228 |
+
<style>
|
229 |
+
.aligncenter {
|
230 |
+
text-align: center;
|
231 |
+
}
|
232 |
+
</style>
|
233 |
+
<p class="aligncenter">
|
234 |
+
<a href="https://github.com/AlekseyKorshuk/huggingartists" target="_blank">
|
235 |
+
<img src="https://img.shields.io/github/stars/AlekseyKorshuk/huggingartists?style=social"/>
|
236 |
+
</a>
|
237 |
+
</p>
|
238 |
+
<p class="aligncenter">
|
239 |
+
<a href="https://t.me/joinchat/_CQ04KjcJ-4yZTky" target="_blank">
|
240 |
+
<img src="https://img.shields.io/badge/dynamic/json?color=blue&label=Telegram%20Channel&query=%24.result&url=https%3A%2F%2Fapi.telegram.org%2Fbot1929545866%3AAAFGhV-KKnegEcLiyYJxsc4zV6C-bdPEBtQ%2FgetChatMemberCount%3Fchat_id%3D-1001253621662&style=social&logo=telegram"/>
|
241 |
+
</a>
|
242 |
+
</p>
|
243 |
+
""",
|
244 |
+
unsafe_allow_html=True,
|
245 |
+
)
|