hgw3lss commited on
Commit
09ab366
Β·
1 Parent(s): 108539a

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +245 -0
app.py ADDED
@@ -0,0 +1,245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import math
3
+ import random
4
+ import os
5
+ import streamlit as st
6
+ import lyricsgenius
7
+ import transformers
8
+ from transformers import AutoTokenizer, AutoModelForCausalLM
9
+
10
+
11
+
12
+ st.set_page_config(page_title="HuggingArtists")
13
+
14
+
15
+ st.title("HuggingArtists")
16
+ st.sidebar.markdown(
17
+ """
18
+ <style>
19
+ .aligncenter {
20
+ text-align: center;
21
+ }
22
+ </style>
23
+ <p class="aligncenter">
24
+ <img src="https://raw.githubusercontent.com/AlekseyKorshuk/huggingartists/master/img/logo.jpg" width="420" />
25
+ </p>
26
+ """,
27
+ unsafe_allow_html=True,
28
+ )
29
+ st.sidebar.markdown(
30
+ """
31
+ <style>
32
+ .aligncenter {
33
+ text-align: center;
34
+ }
35
+ </style>
36
+
37
+ <p style='text-align: center'>
38
+ <a href="https://github.com/AlekseyKorshuk/huggingartists" target="_blank">GitHub</a> | <a href="https://wandb.ai/huggingartists/huggingartists/reportlist" target="_blank">Project Report</a>
39
+ </p>
40
+
41
+ <p class="aligncenter">
42
+ <a href="https://github.com/AlekseyKorshuk/huggingartists" target="_blank">
43
+ <img src="https://img.shields.io/github/stars/AlekseyKorshuk/huggingartists?style=social"/>
44
+ </a>
45
+ </p>
46
+ <p class="aligncenter">
47
+ <a href="https://t.me/joinchat/_CQ04KjcJ-4yZTky" target="_blank">
48
+ <img src="https://img.shields.io/badge/dynamic/json?color=blue&label=Telegram%20Channel&query=%24.result&url=https%3A%2F%2Fapi.telegram.org%2Fbot1929545866%3AAAFGhV-KKnegEcLiyYJxsc4zV6C-bdPEBtQ%2FgetChatMemberCount%3Fchat_id%3D-1001253621662&style=social&logo=telegram"/>
49
+ </a>
50
+ </p>
51
+ <p class="aligncenter">
52
+ <a href="https://colab.research.google.com/github/AlekseyKorshuk/huggingartists/blob/master/huggingartists-demo.ipynb" target="_blank">
53
+ <img src="https://colab.research.google.com/assets/colab-badge.svg"/>
54
+ </a>
55
+ </p>
56
+ """,
57
+ unsafe_allow_html=True,
58
+ )
59
+
60
+
61
+
62
+ st.sidebar.header("Generation settings:")
63
+ num_sequences = st.sidebar.number_input(
64
+ "Number of sequences to generate",
65
+ min_value=1,
66
+ value=5,
67
+ help="The amount of generated texts",
68
+ )
69
+ min_length = st.sidebar.number_input(
70
+ "Minimum length of the sequence",
71
+ min_value=1,
72
+ value=100,
73
+ help="The minimum length of the sequence to be generated",
74
+ )
75
+ max_length= st.sidebar.number_input(
76
+ "Maximum length of the sequence",
77
+ min_value=1,
78
+ value=160,
79
+ help="The maximum length of the sequence to be generated",
80
+ )
81
+ temperature = st.sidebar.slider(
82
+ "Temperature",
83
+ min_value=0.0,
84
+ max_value=3.0,
85
+ step=0.01,
86
+ value=1.0,
87
+ help="The value used to module the next token probabilities",
88
+ )
89
+ top_p = st.sidebar.slider(
90
+ "Top-P",
91
+ min_value=0.0,
92
+ max_value=1.0,
93
+ step=0.01,
94
+ value=0.95,
95
+ help="If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation.",
96
+ )
97
+
98
+ top_k= st.sidebar.number_input(
99
+ "Top-K",
100
+ min_value=0,
101
+ value=50,
102
+ step=1,
103
+ help="The number of highest probability vocabulary tokens to keep for top-k-filtering.",
104
+ )
105
+
106
+ caption = (
107
+ "In [HuggingArtists](https://github.com/AlekseyKorshuk/huggingartist), we can generate lyrics by a specific artist. This was made by fine-tuning a pre-trained [HuggingFace Transformer](https://huggingface.co) on parsed datasets from [Genius](https://genius.com)."
108
+ )
109
+ st.markdown("[HuggingArtists](https://github.com/AlekseyKorshuk/huggingartist) - Train a model to generate lyrics 🎡")
110
+ st.markdown(caption)
111
+
112
+ st.subheader("Settings:")
113
+ artist_name = st.text_input("Artist name:", "Eminem")
114
+ start = st.text_input("Beginning of the song:", "But for me to rap like a computer")
115
+
116
+ TOKEN = "q_JK_BFy9OMiG7fGTzL-nUto9JDv3iXI24aYRrQnkOvjSCSbY4BuFIindweRsr5I"
117
+ genius = lyricsgenius.Genius(TOKEN)
118
+
119
+ model_html = """
120
+
121
+ <div class="inline-flex flex-col" style="line-height: 1.5;">
122
+ <div class="flex">
123
+ <div
124
+ \t\t\tstyle="display:DISPLAY_1; margin-left: auto; margin-right: auto; width: 92px; height:92px; border-radius: 50%; background-size: cover; background-image: url(&#39;USER_PROFILE&#39;)">
125
+ </div>
126
+ </div>
127
+ <div style="text-align: center; margin-top: 3px; font-size: 16px; font-weight: 800">πŸ€– HuggingArtists Model πŸ€–</div>
128
+ <div style="text-align: center; font-size: 16px; font-weight: 800">USER_NAME</div>
129
+ <a href="https://genius.com/artists/USER_HANDLE">
130
+ \t<div style="text-align: center; font-size: 14px;">@USER_HANDLE</div>
131
+ </a>
132
+ </div>
133
+ """
134
+
135
+
136
+ def post_process(output_sequences):
137
+ predictions = []
138
+ generated_sequences = []
139
+
140
+ max_repeat = 2
141
+
142
+ # decode prediction
143
+ for generated_sequence_idx, generated_sequence in enumerate(output_sequences):
144
+ generated_sequence = generated_sequence.tolist()
145
+ text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True, skip_special_tokens=True)
146
+ generated_sequences.append(text.strip())
147
+
148
+ for i, g in enumerate(generated_sequences):
149
+ res = str(g).replace('\n\n\n', '\n').replace('\n\n', '\n')
150
+ lines = res.split('\n')
151
+ # print(lines)
152
+ # i = max_repeat
153
+ # while i != len(lines):
154
+ # remove_count = 0
155
+ # for index in range(0, max_repeat):
156
+ # # print(i - index - 1, i - index)
157
+ # if lines[i - index - 1] == lines[i - index]:
158
+ # remove_count += 1
159
+ # if remove_count == max_repeat:
160
+ # lines.pop(i)
161
+ # i -= 1
162
+ # else:
163
+ # i += 1
164
+ predictions.append('\n'.join(lines))
165
+
166
+ return predictions
167
+
168
+ if st.button("Run"):
169
+ model_name = None
170
+ with st.spinner(text=f"Searching for {artist_name } in Genius..."):
171
+ artist = genius.search_artist(artist_name, max_songs=0, get_full_info=False)
172
+ if artist is not None:
173
+ artist_dict = genius.artist(artist.id)['artist']
174
+ artist_url = str(artist_dict['url'])
175
+ model_name = artist_url[artist_url.rfind('/') + 1:].lower()
176
+ st.markdown(model_html.replace("USER_PROFILE",artist.image_url).replace("USER_NAME",artist.name).replace("USER_HANDLE",model_name), unsafe_allow_html=True)
177
+ else:
178
+ st.markdown(f"Could not find {artist_name}! Be sure that he/she exists in [Genius](https://genius.com/).")
179
+ if model_name is not None:
180
+ with st.spinner(text=f"Downloading the model of {artist_name }..."):
181
+ model = None
182
+ tokenizer = None
183
+ try:
184
+ tokenizer = AutoTokenizer.from_pretrained(f"huggingartists/{model_name}")
185
+ model = AutoModelForCausalLM.from_pretrained(f"huggingartists/{model_name}")
186
+ except Exception as ex:
187
+ # st.markdown(ex)
188
+ st.markdown(f"Model for this artist does not exist yet. Create it in just 5 min with [Colab Notebook](https://colab.research.google.com/github/AlekseyKorshuk/huggingartists/blob/master/huggingartists-demo.ipynb):")
189
+ st.markdown(
190
+ """
191
+ <style>
192
+ .aligncenter {
193
+ text-align: center;
194
+ }
195
+ </style>
196
+ <p class="aligncenter">
197
+ <a href="https://colab.research.google.com/github/AlekseyKorshuk/huggingartists/blob/master/huggingartists-demo.ipynb" target="_blank">
198
+ <img src="https://colab.research.google.com/assets/colab-badge.svg"/>
199
+ </a>
200
+ </p>
201
+ """,
202
+ unsafe_allow_html=True,
203
+ )
204
+ if model is not None:
205
+ with st.spinner(text=f"Generating lyrics..."):
206
+ encoded_prompt = tokenizer(start, add_special_tokens=False, return_tensors="pt").input_ids
207
+ encoded_prompt = encoded_prompt.to(model.device)
208
+ # prediction
209
+ output_sequences = model.generate(
210
+ input_ids=encoded_prompt,
211
+ max_length=max_length,
212
+ min_length=min_length,
213
+ temperature=float(temperature),
214
+ top_p=float(top_p),
215
+ top_k=int(top_k),
216
+ do_sample=True,
217
+ repetition_penalty=1.0,
218
+ num_return_sequences=num_sequences
219
+ )
220
+ # Post-processing
221
+ predictions = post_process(output_sequences)
222
+ st.subheader("Results")
223
+ for prediction in predictions:
224
+ st.text(prediction)
225
+ st.subheader("Please star this repository and join my Telegram Channel:")
226
+ st.markdown(
227
+ """
228
+ <style>
229
+ .aligncenter {
230
+ text-align: center;
231
+ }
232
+ </style>
233
+ <p class="aligncenter">
234
+ <a href="https://github.com/AlekseyKorshuk/huggingartists" target="_blank">
235
+ <img src="https://img.shields.io/github/stars/AlekseyKorshuk/huggingartists?style=social"/>
236
+ </a>
237
+ </p>
238
+ <p class="aligncenter">
239
+ <a href="https://t.me/joinchat/_CQ04KjcJ-4yZTky" target="_blank">
240
+ <img src="https://img.shields.io/badge/dynamic/json?color=blue&label=Telegram%20Channel&query=%24.result&url=https%3A%2F%2Fapi.telegram.org%2Fbot1929545866%3AAAFGhV-KKnegEcLiyYJxsc4zV6C-bdPEBtQ%2FgetChatMemberCount%3Fchat_id%3D-1001253621662&style=social&logo=telegram"/>
241
+ </a>
242
+ </p>
243
+ """,
244
+ unsafe_allow_html=True,
245
+ )