MasterAlex69 committed on
Commit
3fb1820
·
verified ·
1 Parent(s): 3ac4624

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +245 -0
app.py ADDED
@@ -0,0 +1,245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# NOTE: the notebook-only shell magic `!pip -q install gradio` was removed here:
# it is IPython syntax and a SyntaxError in a plain Python file such as app.py.
# Install gradio through the environment instead (e.g. list it in requirements.txt).
2
+
3
+ import gradio as gr
4
+ from transformers import pipeline, GPT2Tokenizer, AutoModelForSequenceClassification, AutoTokenizer
5
+ from IPython.display import clear_output
6
+ import joblib, torch
7
+
8
+ ############################################################################################
9
+
10
# Run on CUDA when torch reports it as available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hub ids of the two generator checkpoints. In the UI, generator 0 is labelled
# "ДО" (before) and generator 1 "ПОСЛЕ" (after).
generator_name_0 = 'MasterAlex69/gpt2_edline'
generator_name_1 = 'MasterAlex69/gpt2_edline_gan'

generator_tokenizer_0 = GPT2Tokenizer.from_pretrained(generator_name_0)
generator_tokenizer_1 = GPT2Tokenizer.from_pretrained(generator_name_1)

# Use the end-of-sequence token id as the padding id for both tokenizers.
generator_tokenizer_0.pad_token_id = generator_tokenizer_0.eos_token_id
generator_tokenizer_1.pad_token_id = generator_tokenizer_1.eos_token_id

generator_pipeline_0 = pipeline(
    'text-generation',
    model=generator_name_0,
    tokenizer=generator_tokenizer_0,
    device=device,
)
generator_pipeline_1 = pipeline(
    'text-generation',
    model=generator_name_1,
    tokenizer=generator_tokenizer_1,
    device=device,
)

generator_pkl_name_0 = 'generator_pkl_0.pkl'
generator_pkl_name_1 = 'generator_pkl_1.pkl'

# Serialize both pipelines into the current working directory.
joblib.dump(generator_pipeline_0, generator_pkl_name_0)
joblib.dump(generator_pipeline_1, generator_pkl_name_1)

# Reload from exactly the paths written above. The previous code prefixed
# '/content/', which only resolves when the working directory is /content
# and raises FileNotFoundError anywhere else.
generator_pipeline_0 = joblib.load(generator_pkl_name_0)
generator_pipeline_1 = joblib.load(generator_pkl_name_1)
32
+
33
+ ############################################################################################
34
+
35
# Hub ids of the two discriminator checkpoints. In the UI, discriminator 0 is
# labelled "ДО" (before) and discriminator 1 "ПОСЛЕ" (after).
discriminator_name_0 = 'MasterAlex69/bert_edline'
discriminator_name_1 = 'MasterAlex69/bert_edline_gan'

# Load the sequence-classification models and move them to the selected device.
discriminator_0 = AutoModelForSequenceClassification.from_pretrained(discriminator_name_0).to(device)
discriminator_1 = AutoModelForSequenceClassification.from_pretrained(discriminator_name_1).to(device)

discriminator_tokenizer_0 = AutoTokenizer.from_pretrained(discriminator_name_0)
discriminator_tokenizer_1 = AutoTokenizer.from_pretrained(discriminator_name_1)

discriminator_pkl_name_0 = 'discriminator_pkl_0.pkl'
discriminator_pkl_name_1 = 'discriminator_pkl_1.pkl'

# Serialize both models into the current working directory.
joblib.dump(discriminator_0, discriminator_pkl_name_0)
joblib.dump(discriminator_1, discriminator_pkl_name_1)

# Reload from exactly the paths written above. The previous code prefixed
# '/content/', which only resolves when the working directory is /content.
discriminator_0 = joblib.load(discriminator_pkl_name_0)
discriminator_1 = joblib.load(discriminator_pkl_name_1)

discriminator_pkl_tokenizer_name_0 = 'discriminator_tokenizer_pkl_0.pkl'
discriminator_pkl_tokenizer_name_1 = 'discriminator_tokenizer_pkl_1.pkl'

# Same dump/reload round-trip for the tokenizers, again using matching paths.
joblib.dump(discriminator_tokenizer_0, discriminator_pkl_tokenizer_name_0)
joblib.dump(discriminator_tokenizer_1, discriminator_pkl_tokenizer_name_1)

discriminator_tokenizer_0 = joblib.load(discriminator_pkl_tokenizer_name_0)
discriminator_tokenizer_1 = joblib.load(discriminator_pkl_tokenizer_name_1)
61
+
62
+ ############################################################################################
63
def generate_text_0():
    """Generate one text with generator pipeline 0 from the fixed prompt.

    Returns:
        The ``'generated_text'`` field of the first result produced by
        ``generator_pipeline_0`` (called with ``max_length=225``).
    """
    prompt = "Строка состоит из символов"
    outputs = generator_pipeline_0(prompt, max_length=225, truncation=False)
    first_output = outputs[0]
    return first_output['generated_text']
65
+
66
def generate_text_1():
    """Generate one text with generator pipeline 1 from the fixed prompt.

    Returns:
        The ``'generated_text'`` field of the first result produced by
        ``generator_pipeline_1`` (called with ``max_length=225``).
    """
    prompt = "Строка состоит из символов"
    outputs = generator_pipeline_1(prompt, max_length=225, truncation=False)
    first_output = outputs[0]
    return first_output['generated_text']
68
+
69
def discriminate_text_0(text):
    """Classify *text* with discriminator 0 and return its binary verdict.

    The text is tokenized with ``discriminator_tokenizer_0`` (padding and
    truncation enabled), passed through ``discriminator_0``, and the last
    logit column is squashed with a sigmoid and rounded.

    Args:
        text: The text to classify.

    Returns:
        ``0`` or ``1`` for the first item in the batch. Callers in this file
        compare it against 1 for a task whose stated answer checks out;
        presumably that is the label convention used in training (confirm).
    """
    inputs = discriminator_tokenizer_0(
        text,
        return_tensors="pt",
        padding=True,
        truncation=True,
    ).to(device)

    # Inference only: disable autograd so no computation graph is built and
    # kept alive for every call.
    with torch.no_grad():
        last_logit = discriminator_0(**inputs).logits[:, -1]

    return torch.round(torch.sigmoid(last_logit)).long().tolist()[0]
78
+
79
def discriminate_text_1(text):
    """Classify *text* with discriminator 1 and return its binary verdict.

    The text is tokenized with ``discriminator_tokenizer_1`` (padding and
    truncation enabled), passed through ``discriminator_1``, and the last
    logit column is squashed with a sigmoid and rounded.

    Args:
        text: The text to classify.

    Returns:
        ``0`` or ``1`` for the first item in the batch. Callers in this file
        compare it against 1 for a task whose stated answer checks out;
        presumably that is the label convention used in training (confirm).
    """
    inputs = discriminator_tokenizer_1(
        text,
        return_tensors="pt",
        padding=True,
        truncation=True,
    ).to(device)

    # Inference only: disable autograd so no computation graph is built and
    # kept alive for every call.
    with torch.no_grad():
        last_logit = discriminator_1(**inputs).logits[:, -1]

    return torch.round(torch.sigmoid(last_logit)).long().tolist()[0]
88
+
89
def d_test_0(count):
    """Measure how often discriminator 0 agrees with the rule-based checker.

    Generates ``count`` texts with ``generator_pipeline_1``, labels each one
    with ``get_correct_answer`` (1 unless its result contains
    '(не корректно)') and compares that label with ``discriminate_text_0``.

    Args:
        count: Number of texts to generate, as typed into the UI textbox
            (a string) or an int. Must be between 1 and 256.

    Returns:
        The agreement rate as a percentage string such as ``'87.5%'``, or a
        user-facing message when ``count`` is missing, not a positive
        integer, or above 256.
    """
    # Treat empty, non-numeric and non-positive input alike: ask for a count
    # instead of raising ValueError or dividing by a negative number.
    try:
        count = 0 if count == "" else int(count)
    except (TypeError, ValueError):
        count = 0
    if count <= 0:
        return 'Введите количество итераций...'
    if count > 256:
        return 'Максимальное количество итераций: 256.'

    prompts = ['Строка состоит из символов'] * count
    generated = generator_pipeline_1(prompts, max_length=225, batch_size=count)
    # The pipeline returns one list of candidates per prompt; flatten them.
    texts = [item['generated_text'] for sublist in generated for item in sublist]

    matches = 0
    for text in texts:
        expected = 0 if '(не корректно)' in get_correct_answer(text) else 1
        if discriminate_text_0(text) == expected:
            matches += 1

    return str(round(matches / count * 100, 2)) + '%'
109
+
110
def d_test_1(count):
    """Measure how often discriminator 1 agrees with the rule-based checker.

    Generates ``count`` texts with ``generator_pipeline_1``, labels each one
    with ``get_correct_answer`` (1 unless its result contains
    '(не корректно)') and compares that label with ``discriminate_text_1``.

    Args:
        count: Number of texts to generate, as typed into the UI textbox
            (a string) or an int. Must be between 1 and 256.

    Returns:
        The agreement rate as a percentage string such as ``'87.5%'``, or a
        user-facing message when ``count`` is missing, not a positive
        integer, or above 256.
    """
    # Treat empty, non-numeric and non-positive input alike: ask for a count
    # instead of raising ValueError or dividing by a negative number.
    try:
        count = 0 if count == "" else int(count)
    except (TypeError, ValueError):
        count = 0
    if count <= 0:
        return 'Введите количество итераций...'
    if count > 256:
        return 'Максимальное количество итераций: 256.'

    prompts = ['Строка состоит из символов'] * count
    generated = generator_pipeline_1(prompts, max_length=225, batch_size=count)
    # The pipeline returns one list of candidates per prompt; flatten them.
    texts = [item['generated_text'] for sublist in generated for item in sublist]

    matches = 0
    for text in texts:
        expected = 0 if '(не корректно)' in get_correct_answer(text) else 1
        if discriminate_text_1(text) == expected:
            matches += 1

    return str(round(matches / count * 100, 2)) + '%'
130
+
131
+
132
def test(count):
    """Report the share of tasks from generator 1 whose stated answer is correct.

    Generates ``count`` texts with ``generator_pipeline_1`` and counts those
    for which ``get_correct_answer`` does not report 'не корректно'.

    Args:
        count: Number of texts to generate, as typed into the UI textbox
            (a string) or an int. Must be between 1 and 256.

    Returns:
        The share of correct tasks as a percentage string such as ``'87.5%'``,
        or a user-facing message when ``count`` is missing, not a positive
        integer, or above 256.
    """
    # Treat empty, non-numeric and non-positive input alike: ask for a count
    # instead of raising ValueError or dividing by a negative number.
    try:
        count = 0 if count == "" else int(count)
    except (TypeError, ValueError):
        count = 0
    if count <= 0:
        return 'Введите количество итераций...'
    if count > 256:
        return 'Максимальное количество итераций: 256.'

    prompts = ['Строка состоит из символов'] * count
    generated = generator_pipeline_1(prompts, max_length=225, batch_size=count)
    # The pipeline returns one list of candidates per prompt; flatten them.
    texts = [item['generated_text'] for sublist in generated for item in sublist]

    right = sum('не корректно' not in get_correct_answer(text) for text in texts)

    return str(round(right / count * 100, 2)) + '%'
148
+
149
def test_base(count):
    """Report the share of tasks from generator 0 whose stated answer is correct.

    Generates ``count`` texts with ``generator_pipeline_0`` and counts those
    for which ``get_correct_answer`` does not report 'не корректно'.

    Args:
        count: Number of texts to generate, as typed into the UI textbox
            (a string) or an int. Must be between 1 and 256.

    Returns:
        The share of correct tasks as a percentage string such as ``'87.5%'``,
        or a user-facing message when ``count`` is missing, not a positive
        integer, or above 256.
    """
    # Treat empty, non-numeric and non-positive input alike: ask for a count
    # instead of raising ValueError or dividing by a negative number.
    try:
        count = 0 if count == "" else int(count)
    except (TypeError, ValueError):
        count = 0
    if count <= 0:
        return 'Введите количество итераций...'
    if count > 256:
        return 'Максимальное количество итераций: 256.'

    prompts = ['Строка состоит из символов'] * count
    generated = generator_pipeline_0(prompts, max_length=225, batch_size=count)
    # The pipeline returns one list of candidates per prompt; flatten them.
    texts = [item['generated_text'] for sublist in generated for item in sublist]

    right = sum('не корректно' not in get_correct_answer(text) for text in texts)

    return str(round(right / count * 100, 2)) + '%'
165
+
166
def get_correct_answer(t):
    """Recompute the answer of a task text and compare it with the stated one.

    Three fragments are cut out of *t* by fixed markers and offsets:

    * the stated answer: from 8 characters after the first "(" up to the
      next ")";
    * the target symbol: from 11 characters after the first "д символов "
      up to the next ".";
    * the sequence: from 3 characters after the first "а: " up to the
      next ".".

    Every occurrence of the target symbol in the sequence is replaced by
    '*', and the longest run of consecutive '*' is the recomputed answer.

    Marker lookups are not validated: when a marker is absent ``str.find``
    yields -1 and the offsets are applied to that value as-is.

    Args:
        t: The task text.

    Returns:
        'Введите задание...' for empty input; otherwise the recomputed run
        length followed by ' (корректно)' when it equals the stated answer
        and ' (не корректно)' when it does not.
    """
    if len(t) == 0:
        return 'Введите задание...'

    def _fragment(marker, closer, skip):
        # Slice from `skip` characters past the marker up to the first
        # `closer` found at or after the marker position.
        begin = t.find(marker)
        return t[begin + skip:t.find(closer, begin)]

    stated = _fragment("(", ")", 8)
    symbol = _fragment("д символов ", ".", 11)
    sequence = _fragment("а: ", ".", 3)

    masked = sequence.replace(symbol, '*')

    longest = 0
    streak = 0
    for ch in masked:
        streak = streak + 1 if ch == '*' else 0
        longest = max(longest, streak)

    verdict = ' (корректно)' if str(longest) == stated else ' (не корректно)'
    return str(longest) + verdict
196
+
197
+ ############################################################################################
198
# Gradio UI: each row holds one or two columns, and each column wires one
# button to one of the handlers defined above.
with gr.Blocks(theme=gr.themes.Monochrome()) as iface:

    # Row 1: generate a single task with generator 0 ("ДО") or generator 1 ("ПОСЛЕ").
    with gr.Row():
        with gr.Column():
            button_gen_0 = gr.Button("Сгенерировать задание (ДО)")
            button_gen_0_output_text = gr.Textbox(
                label="Результат генерации",
                interactive=False,
            )
            button_gen_0.click(
                fn=generate_text_0,
                outputs=button_gen_0_output_text,
            )

        with gr.Column():
            button_gen_1 = gr.Button("Сгенерировать задание (ПОСЛЕ)")
            button_gen_1_output_text = gr.Textbox(
                label="Результат генерации",
                interactive=False,
            )
            button_gen_1.click(
                fn=generate_text_1,
                outputs=button_gen_1_output_text,
            )

    # Row 2: share of correct tasks produced by each generator.
    with gr.Row():
        with gr.Column():
            button_test = gr.Button("Провести испытание (ДО) генератор")
            test_input_text = gr.Textbox(label="Количество итераций")
            test_output_text = gr.Textbox(label="Корректных заданий")
            button_test.click(
                fn=test_base,
                inputs=test_input_text,
                outputs=test_output_text,
            )

        with gr.Column():
            button_test_ = gr.Button("Провести испытание (ПОСЛЕ) генератор")
            test_input_text_ = gr.Textbox(label="Количество итераций")
            test_output_text_ = gr.Textbox(label="Корректных заданий")
            button_test_.click(
                fn=test,
                inputs=test_input_text_,
                outputs=test_output_text_,
            )

    # Row 3: recompute the answer for a task text pasted by the user.
    with gr.Row():
        with gr.Column():
            button_get_correct_answer = gr.Button("Получить правильный ответ")
            get_correct_answer_input_text = gr.Textbox(label="Задание")
            get_correct_answer_output_text = gr.Textbox(label="Ответ")
            button_get_correct_answer.click(
                fn=get_correct_answer,
                inputs=get_correct_answer_input_text,
                outputs=get_correct_answer_output_text,
            )

    # Row 4: agreement of each discriminator with the rule-based checker.
    with gr.Row():
        with gr.Column():
            bn_test_d_0 = gr.Button("Провести испытание (ДО) дискриминатор")
            bn_test_d_0_text_input = gr.Textbox(label="Количество итераций")
            bn_test_d_0_text_output = gr.Textbox(label="Совпадений")
            bn_test_d_0.click(
                fn=d_test_0,
                inputs=bn_test_d_0_text_input,
                outputs=bn_test_d_0_text_output,
            )

        with gr.Column():
            bn_test_d_1 = gr.Button("Провести испытание (ПОСЛЕ) дискриминатор")
            bn_test_d_1_text_input = gr.Textbox(label="Количество итераций")
            bn_test_d_1_text_output = gr.Textbox(label="Совпадений")
            bn_test_d_1.click(
                fn=d_test_1,
                inputs=bn_test_d_1_text_input,
                outputs=bn_test_d_1_text_output,
            )

# Clear the IPython cell output, then start the app with a public share link.
clear_output()
iface.launch(share=True, debug=False)