4ervonec19 commited on
Commit
fbe1d6d
·
verified ·
1 Parent(s): 8abaca5

class file added

Browse files
Files changed (1) hide show
  1. inference_gptbigcode.py +230 -0
inference_gptbigcode.py ADDED
@@ -0,0 +1,230 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from transformers import AutoTokenizer, AutoModelForCausalLM
5
+
6
+ import logging
7
+ logging.getLogger("transformers").setLevel(logging.ERROR)
8
+
9
+ import warnings
10
+ warnings.filterwarnings("ignore")
11
+
12
+ prompt_string = "Generate test cases for a given Python function using its source code. The output should be a series of assert statements that verify the function's correctness. \n\nExample:\n\nInput:\n\ndef add(a, b):\n\treturn a + b\n \n\nOutput:\n\ndef test_add(a, b):\n\n\tassert add(1, 2) == 3\n\tassert add(-1, 1) == 0\n\tassert add(0, 0) == 0\n\tassert add(1.5, 2.5) == 4.0\n\nMy Input:\n\n"
13
+ max_length = 512
14
+ params_inference = {
15
+ 'max_new_tokens': 512,
16
+ 'do_sample': True,
17
+ 'top_k': 100,
18
+ 'top_p': 0.85,
19
+ }
20
+ num_lines = 5
21
+
22
+ gpt2_name="bigcode/gpt_bigcode-santacoder"
23
+ saved_model_path = '/Users/chervonikov_alexey/Desktop/result_gpt_bigcode_15_epochs/model_val_loss.pth'
24
+ device = "cuda" if torch.cuda.is_available() else "cpu"
25
+
26
+ class LargeCodeModelGPTBigCode(nn.Module):
27
+ '''Класс для большой языковой модели, которая обрабатывает входной код'''
28
+ def __init__(self, gpt2_name=gpt2_name,
29
+ prompt_string = prompt_string,
30
+ params_inference = params_inference,
31
+ max_length = max_length,
32
+ device = device,
33
+ saved_model_path = saved_model_path,
34
+ num_lines = num_lines,
35
+ flag_hugging_face = False,
36
+ flag_pretrained = False):
37
+
38
+ '''
39
+ Конструктор класса. Необходим для инициализации модели
40
+
41
+ Параметры:
42
+ -gpt2_name: link модели на HuggingFace (default: "bigcode/gpt_bigcode-santacoder")
43
+ -prompt_string: дополнительная обертка для лучшего понимания моделью задачи (default: prompt_string)
44
+ -params_inference: параметры инференса (для использования self.gpt2.generate(**inputs,
45
+ **inference_params))
46
+ -max_length: максимальное число токенов в последовательности (default: 512)
47
+ -device: устройство (default: cuda or cpu)
48
+ -saved_model_path: путь к затюненной модели
49
+ -num_lines: число линий (ввиду "неконечной" генерации модели)
50
+ -flag_hugging_face: флаг для использования HuggingFace (default: False)
51
+ -flag_pretrained: флаг для инициализации модели затюненными весами
52
+
53
+ '''
54
+ super(LargeCodeModelGPTBigCode, self).__init__()
55
+
56
+ self.new_special_tokens = ['<FUNC_TOKEN>',
57
+ '<INFO_TOKEN>',
58
+ '<CLS_TOKEN>',
59
+ '<AST_TOKEN>',
60
+ '<DESCRIPTION_TOKEN>',
61
+ '<COMMENTS_TOKEN>']
62
+
63
+ self.special_tokens_dict = {
64
+ 'additional_special_tokens': self.new_special_tokens
65
+ }
66
+
67
+ self.tokenizerGPT = AutoTokenizer.from_pretrained(gpt2_name, padding_side='left')
68
+ self.tokenizerGPT.add_special_tokens({'pad_token': '<PAD>'})
69
+ self.tokenizerGPT.add_special_tokens(self.special_tokens_dict)
70
+ self.gpt2 = AutoModelForCausalLM.from_pretrained(gpt2_name)
71
+ self.gpt2.resize_token_embeddings(len(self.tokenizerGPT))
72
+
73
+ self.max_length = max_length
74
+ self.inference_params = params_inference
75
+ self.additional_prompt = prompt_string
76
+ self.pretrained_path = saved_model_path
77
+ self.device = device
78
+ self.num_lines = num_lines
79
+
80
+ if flag_pretrained == True and flag_hugging_face == False:
81
+ if self.device == "cuda":
82
+ self.load_state_dict(torch.load(self.pretrained_path))
83
+ else:
84
+ self.load_state_dict(torch.load(self.pretrained_path,
85
+ map_location=torch.device('cpu')))
86
+
87
+ # forward call
88
+ def forward(self, input_ids, attention_mask,
89
+ response_ids):
90
+
91
+ '''
92
+ Forward call method
93
+
94
+ Параметры:
95
+ -input_ids: входные токены
96
+ -attention_mask: маска внимания
97
+ -response_ids: метки
98
+
99
+ Returns:
100
+ -результат forward call
101
+
102
+ '''
103
+ gpt2_outputs = self.gpt2(input_ids = input_ids,
104
+ attention_mask = attention_mask,
105
+ labels = response_ids)
106
+
107
+ return gpt2_outputs
108
+
109
+ @staticmethod
110
+ def decode_sequence(tokens_ids, tokenizer):
111
+ '''
112
+ Декодирование последовательности токенов
113
+
114
+ Параметры:
115
+ -tokens_ids: последоавтельность токенов
116
+ -tokenizer: токенизатор
117
+
118
+ Returns:
119
+ -code_decoded: Декодированная последовательность
120
+ '''
121
+
122
+ code_decoded = tokenizer.decode(tokens_ids, skip_special_tokens = True)
123
+ return code_decoded
124
+
125
+ @staticmethod
126
+ def remove_before_substring(text, substring="My Output:"):
127
+ '''
128
+ Вспомогательная утилита, чтобы убрать все лишнее
129
+
130
+ Параметры:
131
+ -text: строка
132
+ -substring: подстрока (default: "My Output:")
133
+
134
+ Returns:
135
+ -text: обновленный текст
136
+
137
+ '''
138
+
139
+ index = text.find(substring)
140
+ if index != -1:
141
+ # Вернуть часть строки, начиная с подстроки
142
+ return text[index:]
143
+ return text
144
+
145
+ @staticmethod
146
+ def extract_text_between_markers(text, start_marker, end_marker):
147
+ '''
148
+ Утилита для работы с входной строкой (получение чистой входной функции)
149
+
150
+ Параметры:
151
+ -text: входная строка
152
+ -start_marker: стартовый маркер
153
+ -end_marker: конечный маркер
154
+
155
+ Returns:
156
+ -Отфильтрованный текст
157
+
158
+ '''
159
+
160
+ start_index = text.find(start_marker)
161
+ end_index = text.find(end_marker)
162
+ if start_index == -1 or end_index == -1 or start_index >= end_index:
163
+ return None
164
+ return text[start_index + len(start_marker):end_index].strip()
165
+
166
+ def input_inference(self, code_text):
167
+ '''
168
+ Инференс входной строки кода
169
+
170
+ Параметры:
171
+ -code_text: строка с кодом
172
+
173
+ Returns:
174
+ -dict: {
175
+ 'input_function': input_function,
176
+ 'generated_output': output_string
177
+ }
178
+
179
+ Словарь в формате input_function + generated_output
180
+ '''
181
+ model_input = self.additional_prompt + code_text + "\n\nMy Output:\n\n"
182
+
183
+ def encode_text(text, tokenizer = self.tokenizerGPT):
184
+ encoding = tokenizer.encode_plus(
185
+ text,
186
+ add_special_tokens=True,
187
+ max_length=max_length,
188
+ padding='max_length',
189
+ truncation=True,
190
+ return_attention_mask=True,
191
+ return_tensors='pt',
192
+ )
193
+ input_ids_code_text = encoding['input_ids'].flatten()
194
+ attention_mask_code_text = encoding['attention_mask'].flatten()
195
+ return input_ids_code_text, attention_mask_code_text
196
+
197
+ input_ids_focal_method, attention_mask_focal_method = encode_text(model_input)
198
+
199
+ self.eval()
200
+ with torch.no_grad():
201
+
202
+ inputs = {'input_ids': input_ids_focal_method.unsqueeze(0).to(device),
203
+ 'attention_mask': attention_mask_focal_method.unsqueeze(0).to(device)}
204
+
205
+ input_function = self.decode_sequence(tokens_ids = inputs['input_ids'][0],
206
+ tokenizer=self.tokenizerGPT)
207
+
208
+ input_function = self.extract_text_between_markers(input_function,
209
+ 'My Input:',
210
+ 'My Output:'),
211
+
212
+ output = self.gpt2.generate(**inputs,
213
+ **self.inference_params)
214
+
215
+ output_string = self.decode_sequence(tokens_ids = output[0],
216
+ tokenizer=self.tokenizerGPT)
217
+
218
+ output_string = self.remove_before_substring(output_string) \
219
+ .replace('My Output:', "") \
220
+ .strip()
221
+ output_string = "\n".join(output_string.splitlines()[:self.num_lines])
222
+
223
+ return {
224
+ 'input_function': input_function,
225
+ 'generated_output': output_string
226
+ }
227
+
228
+ if __name__ == "__main__":
229
+ CodeModel = LargeCodeModelGPTBigCode(gpt2_name, flag_pretrained=True)
230
+