ElijahDi commited on
Commit
1f92bd2
·
verified ·
1 Parent(s): 276fc12

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -9
app.py CHANGED
@@ -8,8 +8,9 @@ from torch import tensor
8
 
9
  import joblib
10
  from dataclasses import dataclass
11
- from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification
12
  import json
 
13
 
14
  from preprocessing import predict_review, data_preprocessing_hard
15
  from model_lstm import LSTMClassifier
@@ -173,17 +174,39 @@ elif selected_model == "Оценка степени токсичности по
173
 
174
 
175
 
176
- # Генерация текста GPT-моделью
177
  elif selected_model == "Генерация текста GPT-моделью по пользовательскому prompt":
178
  st.title("""
179
- Приложение генерирует текст по Вашему промту
180
  """)
181
 
182
  st.write("""
183
- Для генерации текста используется предобученная сеть GPT.
 
184
  """)
185
- uploaded_img = st.sidebar.file_uploader('Загрузи свое космофото', type=["jpg", "png", "jpeg"])
186
- if uploaded_img is not None:
187
- input_img = io.imread(uploaded_img)
188
- else:
189
- input_img = io.imread('/Users/id/Documents/strlit/cv_project/Segm.jpg')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
  import joblib
10
  from dataclasses import dataclass
11
+ from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification, GPT2LMHeadModel, GPT2Tokenizer
12
  import json
13
+ import os
14
 
15
  from preprocessing import predict_review, data_preprocessing_hard
16
  from model_lstm import LSTMClassifier
 
174
 
175
 
176
 
177
+ # Генерация текста GPT
178
  elif selected_model == "Генерация текста GPT-моделью по пользовательскому prompt":
179
  st.title("""
180
+ Нейросетевой гороскоп Якубовский-Дьяченко
181
  """)
182
 
183
  st.write("""
184
+ Для генерации текста используется предобученная сеть GPT2. Дообучение проходило на гороскопах.
185
+ Общая длина текста для обучения 37 001 887 слов.
186
  """)
187
+ user_text_input = st.text_area('Введите информацию о себе для формиорования гороскопа:')
188
+
189
+ # GPT2
190
+ model_path = "model.safetensors"
191
+ huggingface_token = os.getenv("HF_TOKEN")
192
+ model = GPT2LMHeadModel.from_pretrained(model_path, token=huggingface_token)
193
+ tokenizer = GPT2Tokenizer.from_pretrained(model_path, token=huggingface_token)
194
+
195
+ if st.button('Сделать гороскоп'):
196
+ start_time = time.time()
197
+ input_ids = tokenizer.encode(user_text_input, return_tensors="pt").to(device)
198
+ model.eval()
199
+ with torch.no_grad():
200
+ out = model.generate(input_ids,
201
+ do_sample=True,
202
+ num_beams=2,
203
+ temperature=1.1,
204
+ top_p=0.9,
205
+ max_length=50,
206
+ )
207
+
208
+ generated_text = list(map(tokenizer.decode, out))[0]
209
+ end_time = time.time()
210
+ prediction_time = end_time - start_time
211
+
212
+ st.write(f'Ваше предсказание: {generated_text}')