SaviAnna commited on
Commit
460d569
·
verified ·
1 Parent(s): 8b3e307

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +93 -0
app.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import transformers
2
+ import streamlit as st
3
+ from transformers import GPT2LMHeadModel, GPT2Tokenizer
4
+ import numpy as np
5
+ from PIL import Image
6
+ import torch
7
+
8
+ st.title("""
9
+ Fine-tuned GPT-2 for New Language with Custom Tokenizer
10
+ """)
11
+ # Добавление слайдера
12
+ temperature = st.slider("Temerature", 1, 20, 1)
13
+ max_len = st.slider("Length", 40, 120, 2)
14
+ # Загрузка модели и токенизатора
15
+ # model = GPT2LMHeadModel.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2')
16
+ # tokenizer = GPT2Tokenizer.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2')
17
+ # #Задаем класс модели (уже в streamlit/tg_bot)
18
+
19
+ @st.cache
20
+ # def load_gpt():
21
+ # model_GPT = GPT2LMHeadModel.from_pretrained(
22
+ # 'sberbank-ai/rugpt3small_based_on_gpt2',
23
+ # output_attentions = False,
24
+ # output_hidden_states = False,
25
+ # )
26
+ # tokenizer_GPT = GPT2Tokenizer.from_pretrained(
27
+ # 'sberbank-ai/rugpt3small_based_on_gpt2',
28
+ # output_attentions = False,
29
+ # output_hidden_states = False,
30
+ # )
31
+ # gpt2_tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
32
+ # model_GPT.load_state_dict(torch.load('model_history_friday.pt', map_location=torch.device('cpu')))
33
+ # return model_GPT, tokenizer_GPT
34
+ def load_gpt_base():
35
+ model_GPT = GPT2LMHeadModel.from_pretrained("gpt2")
36
+ tokenizer_GPT = GPT2TokenizerFast.from_pretrained("gpt2")
37
+ return model_GPT, tokenizer_GPT
38
+
39
+ # # Вешаем сохраненные веса на нашу модель
40
+
41
+ # Функция для генерации текста
42
+ def generate_text(model_GPT, tokenizer_GPT, prompt):
43
+ # Преобразование входной строки в токены
44
+ input_ids = tokenizer_GPT.encode(prompt, return_tensors='pt')
45
+
46
+ # Генерация текста
47
+ output = model_GPT.generate(input_ids=input_ids, max_length=70, num_beams=5, do_sample=True,
48
+ temperature=1., top_k=50, top_p=0.6, no_repeat_ngram_size=3,
49
+ num_return_sequences=3)
50
+
51
+ # Декодирование сгенерированного текста
52
+ generated_text = tokenizer_GPT.decode(output[0], skip_special_tokens=True)
53
+
54
+ return generated_text
55
+
56
+ # Streamlit приложение
57
+ def main():
58
+ model_GPT, tokenizer_GPT = load_gpt()
59
+ st.write("""
60
+ # Fine-tuned GPT-2 for New Language with Custom Tokenizer
61
+ """)
62
+
63
+ # Ввод строки пользователем
64
+ prompt = st.text_area("Какую фразу нужно продолжить:", value="В средние века")
65
+
66
+ # # Генерация текста по введенной строке
67
+ # generated_text = generate_text(prompt)
68
+ # Создание кнопки "Сгенерировать"
69
+ generate_button = st.button("Complete!")
70
+ # Обработка события нажатия кнопки
71
+ if generate_button:
72
+ # Вывод сгенерированного текста
73
+ #generated_text = generate_text(model_GPT, tokenizer_GPT, prompt)
74
+ generated_text = 'test'
75
+ st.subheader("Completed prompt:")
76
+ st.write(generated_text)
77
+
78
+ # Ввод строки пользователем
79
+ prompt1 = st.text_area("Какую фразу нужно продолжить:", value="В средние века")
80
+ # # Генерация текста по введенной строке
81
+ # generated_text = generate_text(prompt)
82
+ # Создание кнопки "Сгенерировать"
83
+ generate_button1 = st.button("Complete!")
84
+ # Обработка события нажатия кнопки
85
+ if generate_button1:
86
+ # Вывод сгенерированного текста
87
+ #generated_text = generate_text(model_GPT, tokenizer_GPT, prompt)
88
+ generated_text = 'test'
89
+ st.subheader("Completed prompt:")
90
+ st.write(generated_text)
91
+
92
+ if __name__ == "__main__":
93
+ main()