ElijahDi commited on
Commit
c12d385
·
verified ·
1 Parent(s): 82b7b3e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -19
app.py CHANGED
@@ -202,24 +202,26 @@ elif selected_model == "Генерация текста GPT-моделью по
202
  model = GPT2LMHeadModel.from_pretrained(path).to(device)
203
  tokenizer = GPT2Tokenizer.from_pretrained(path)
204
 
205
- if st.button('Сделать гороскоп'):
206
- with st.spinner('Генерация текста...'):
207
- start_time = time.time()
208
- input_ids = tokenizer.encode(user_text_input, return_tensors="pt").to(device)
209
- model.eval()
210
- with torch.no_grad():
211
- out = model.generate(
212
- input_ids,
213
- do_sample=True,
214
- num_beams=2,
215
- temperature=1.1,
216
- top_p=0.9,
217
- max_length=50,
218
- )
219
-
220
- generated_text = tokenizer.decode(out[0], skip_special_tokens=True)
221
- end_time = time.time()
222
- prediction_time = end_time - start_time
223
 
224
- st.success('Готово!')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
225
  st.write(f'{generated_text}')
 
202
  model = GPT2LMHeadModel.from_pretrained(path).to(device)
203
  tokenizer = GPT2Tokenizer.from_pretrained(path)
204
 
205
+ temperature = st.slider('Temperature', 0.1, 2.0, 1.1, step=0.1)
206
+ max_gen_length = st.slider('Максимальная длина генерации', 10, 500, 100, step=10)
207
+ num_generations = st.slider('Количество генераций', 1, 10, 2, step=1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
208
 
209
+ if st.button('Сделать гороскоп'):
210
+ with st.spinner('Генерация текста...'):
211
+ start_time = time.time()
212
+ input_ids = tokenizer.encode(user_text_input, return_tensors="pt").to(device)
213
+ model.eval()
214
+ with torch.no_grad():
215
+ out = model.generate(
216
+ input_ids,
217
+ do_sample=True,
218
+ num_beams=num_generations,
219
+ temperature=temperature,
220
+ top_p=0.9,
221
+ max_length=max_gen_length,
222
+ )
223
+ generated_text = tokenizer.decode(out[0], skip_special_tokens=True)
224
+ end_time = time.time()
225
+ prediction_time = end_time - start_time
226
+
227
  st.write(f'{generated_text}')