SaviAnna commited on
Commit
12d19c9
·
verified ·
1 Parent(s): 9177aa5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +46 -31
app.py CHANGED
@@ -8,54 +8,69 @@ st.title("""
8
  """)
9
 
10
  # Слайдеры для управления температурой и длиной текста
11
- temperature = st.slider("Temperature", 0.1, 2.0, 1.0)
12
- max_len = st.slider("Max Length", 40, 120, 70)
13
 
14
- # Кеширование модели и токенизатора
15
  @st.cache_resource
16
- def load_gpt_base():
17
- model_GPT = GPT2LMHeadModel.from_pretrained("gpt2")
18
- tokenizer_GPT = GPT2TokenizerFast.from_pretrained("gpt2")
19
- return model_GPT, tokenizer_GPT
 
 
 
 
 
 
 
 
20
 
21
  # Функция для генерации текста
22
- def generate_text(model_GPT, tokenizer_GPT, prompt, max_len, temperature):
23
- # Преобразование входной строки в токены
24
- input_ids = tokenizer_GPT.encode(prompt, return_tensors='pt')
25
 
26
  # Генерация текста
27
- output = model_GPT.generate(input_ids=input_ids,
28
- max_length=max_len,
29
- do_sample=True,
30
- temperature=temperature,
31
- top_k=50,
32
- top_p=0.6,
33
- no_repeat_ngram_size=3,
34
- num_return_sequences=1)
35
-
36
- # Декодирование сгенерированного текста
37
- generated_text = tokenizer_GPT.decode(output[0], skip_special_tokens=True)
38
  return generated_text
39
 
40
  # Streamlit приложение
41
  def main():
42
- model_GPT, tokenizer_GPT = load_gpt_base()
 
43
 
44
  st.write("""
45
  # Fine-tuned GPT-2 for New Language with Custom Tokenizer
46
  """)
47
 
48
- # Ввод строки пользователем для генерации текста
49
- prompt = st.text_area("Введите фразу для генерации:", value="В средние века")
 
 
 
 
 
 
 
50
 
51
- # Создание кнопки для генерации
52
- generate_button = st.button("Сгенерировать текст")
 
 
53
 
54
- # Обработка события нажатия кнопки
55
- if generate_button:
56
- generated_text = generate_text(model_GPT, tokenizer_GPT, prompt, max_len, temperature)
57
- st.subheader("Сгенерированный текст:")
58
- st.write(generated_text)
59
 
60
  if __name__ == "__main__":
61
  main()
 
8
  """)
9
 
10
  # Слайдеры для управления температурой и длиной текста
11
+ temperature = st.slider("Temperature", 0.1, 2.0, 1.0) # Для обеих моделей
12
+ max_len = st.slider("Max Length", 40, 120, 70) # Для обеих моделей
13
 
14
+ # Кеширование модели и токенизатора GPT-2
15
  @st.cache_resource
16
+ def load_gpt2():
17
+ model_gpt2 = GPT2LMHeadModel.from_pretrained("gpt2")
18
+ tokenizer_gpt2 = GPT2TokenizerFast.from_pretrained("gpt2")
19
+ return model_gpt2, tokenizer_gpt2
20
+
21
+ # Кеширование кастомной модели и токенизатора
22
+ @st.cache_resource
23
+ def load_custom_model():
24
+ # Здесь замените путь на вашу кастомную модель
25
+ model_custom = GPT2LMHeadModel.from_pretrained("rus_gpt2_tuned")
26
+ tokenizer_custom = GPT2TokenizerFast.from_pretrained("rus_gpt2_tuned")
27
+ return model_custom, tokenizer_custom
28
 
29
  # Функция для генерации текста
30
+ def generate_text(model, tokenizer, prompt, max_len, temperature):
31
+ input_ids = tokenizer.encode(prompt, return_tensors='pt')
 
32
 
33
  # Генерация текста
34
+ output = model.generate(input_ids=input_ids,
35
+ max_length=max_len,
36
+ do_sample=True,
37
+ temperature=temperature,
38
+ top_k=50,
39
+ top_p=0.6,
40
+ no_repeat_ngram_size=3,
41
+ num_return_sequences=1)
42
+
43
+ generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
 
44
  return generated_text
45
 
46
  # Streamlit приложение
47
  def main():
48
+ model_gpt2, tokenizer_gpt2 = load_gpt2() # GPT-2 модель
49
+ model_custom, tokenizer_custom = load_custom_model() # Кастомная модель
50
 
51
  st.write("""
52
  # Fine-tuned GPT-2 for New Language with Custom Tokenizer
53
  """)
54
 
55
+ # Блок для генерации текста с GPT-2
56
+ st.subheader("GPT-2 Text Generation")
57
+ prompt_gpt2 = st.text_area("Введите фразу для GPT-2 генерации:", value="В средние века")
58
+ generate_button_gpt2 = st.button("Сгенерировать текст с GPT-2")
59
+
60
+ if generate_button_gpt2:
61
+ generated_text_gpt2 = generate_text(model_gpt2, tokenizer_gpt2, prompt_gpt2, max_len, temperature)
62
+ st.subheader("Результат генерации GPT-2:")
63
+ st.write(generated_text_gpt2)
64
 
65
+ # Блок для генерации текста с кастомной моделью
66
+ st.subheader("Custom Model Text Generation")
67
+ prompt_custom = st.text_area("Введите фразу для генерации с кастомной моделью:", value="Когда-то давно")
68
+ generate_button_custom = st.button("Сгенерировать текст с кастомной моделью")
69
 
70
+ if generate_button_custom:
71
+ generated_text_custom = generate_text(model_custom, tokenizer_custom, prompt_custom, max_len, temperature)
72
+ st.subheader("Результат генерации с кастомной моде��ью:")
73
+ st.write(generated_text_custom)
 
74
 
75
  if __name__ == "__main__":
76
  main()