research14 committed · Commit 9af2839 · 1 Parent(s): e9ec0aa

updated run_llm.py

Files changed (1): run_llm.py (+303 −74)
run_llm.py CHANGED
@@ -13,37 +13,41 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaForCausalLM,
 from fastchat.model import load_model, get_conversation_template, add_model_args
 
 
+from nltk.tag.mapping import _UNIVERSAL_TAGS
+
+uni_tags = list(_UNIVERSAL_TAGS)
+uni_tags[-1] = 'PUNC'
+
+bio_tags = ['B', 'I', 'O']
+chunk_tags = ['ADJP', 'ADVP', 'CONJP', 'INTJ', 'LST', 'NP', 'O', 'PP', 'PRT', 'SBAR', 'UCP', 'VP']
+
+syntags = ['NP', 'S', 'VP', 'ADJP', 'ADVP', 'SBAR', 'TOP', 'PP', 'POS', 'NAC', "''", 'SINV', 'PRN', 'QP', 'WHNP', 'RB', 'FRAG',
+           'WHADVP', 'NX', 'PRT', 'VBZ', 'VBP', 'MD', 'NN', 'WHPP', 'SQ', 'SBARQ', 'LST', 'INTJ', 'X', 'UCP', 'CONJP', 'NNP', 'CD', 'JJ',
+           'VBD', 'WHADJP', 'PRP', 'RRC', 'NNS', 'SYM', 'CC']
+
 openai.api_key = "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"
 
 
 # determinant vs. determiner
 # https://wikidiff.com/determiner/determinant
-ents_prompt = [
-    'Noun',
-    'Verb',
-    'Adjective',
-    'Adverb',
-    'Preposition/Subord',
-    'Coordinating Conjunction',
-    # 'Cardinal Number',
+ents_prompt = ['Noun', 'Verb', 'Adjective', 'Adverb', 'Preposition/Subord', 'Coordinating Conjunction',  # 'Cardinal Number',
     'Determiner',
-    'Noun Phrase',
-    'Verb Phrase',
-    'Adjective Phrase',
-    'Adverb Phrase',
-    'Preposition Phrase',
-    'Conjunction Phrase',
-    'Coordinate Phrase',
-    'Quantitative Phrase',
-    'Complex Nominal',
-    'Clause',
-    'Dependent Clause',
-    'Fragment Clause',
-    'T-unit',
-    'Complex T-unit',
-    # 'Fragment T-unit',
-]
-ents = ['NN', 'VB', 'JJ', 'RB', 'IN', 'CC', 'DT', 'NP', 'VP', 'ADJP', 'ADVP', 'PP', 'CONJP', 'CP', 'QP', 'CN', 'C', 'DC', 'FC', 'T', 'CT']
+    'Noun Phrase', 'Verb Phrase', 'Adjective Phrase', 'Adverb Phrase', 'Preposition Phrase', 'Conjunction Phrase', 'Coordinate Phrase', 'Quantitative Phrase', 'Complex Nominal',
+    'Clause', 'Dependent Clause', 'Fragment Clause', 'T-unit', 'Complex T-unit',  # 'Fragment T-unit',
+][7:]
+ents = ['NN', 'VB', 'JJ', 'RB', 'IN', 'CC', 'DT', 'NP', 'VP', 'ADJP', 'ADVP', 'PP', 'CONJP', 'CP', 'QP', 'CN', 'C', 'DC', 'FC', 'T', 'CT'][7:]
+
+
+ents_prompt_uni_tags = ['Verb', 'Noun', 'Pronoun', 'Adjective', 'Adverb', 'Preposition and Postposition', 'Coordinating Conjunction',
+                        'Determiner', 'Cardinal Number', 'Particles or other function words',
+                        'Words that cannot be assigned a POS tag', 'Punctuation']
+
+ents = uni_tags + ents
+ents_prompt = ents_prompt_uni_tags + ents_prompt
+
+for i, j in zip(ents, ents_prompt):
+    print(i, j)
+# raise
 
 
 model_mapping = {
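Note: the new tag lists are anchored to NLTK's universal POS tagset. A quick sanity-check sketch (the expected output assumes NLTK's documented tag order):

    from nltk.tag.mapping import _UNIVERSAL_TAGS

    print(_UNIVERSAL_TAGS)
    # ('VERB', 'NOUN', 'PRON', 'ADJ', 'ADV', 'ADP', 'CONJ', 'DET', 'NUM', 'PRT', 'X', '.')

    # The commit replaces the trailing '.' with 'PUNC' so every tag is an
    # ordinary word a tokenizer can encode; after the two concatenations,
    # the zip-print in the diff should pair up as:
    # VERB Verb, NOUN Noun, PRON Pronoun, ..., PUNC Punctuation, NP Noun Phrase, ...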
@@ -53,34 +57,43 @@ model_mapping = {
     'vicuna-13b': 'lmsys/vicuna-13b-v1.3',
     'vicuna-33b': 'lmsys/vicuna-33b-v1.3',
     'fastchat-t5': 'lmsys/fastchat-t5-3b-v1.0',
-    # 'llama2': 'meta-llama/Llama-2-7b-chat-hf',
-    'llama-7b': '/data/jiali/llama/hf/7B',
-    'llama-13b': '/data/jiali/llama/hf/13B',
-    'llama-30b': '/data/jiali/llama/hf/30B',
-    'llama-65b': '/data/jiali/llama/hf/65B',
-    'alpaca': '/data/jiali/alpaca-7B',
+    # 'llama2-7b': 'meta-llama/Llama-2-7b-hf',
+    # 'llama2-13b': 'meta-llama/Llama-2-13b-hf',
+    # 'llama2-70b': 'meta-llama/Llama-2-70b-hf',
+    'llama-7b': './llama/hf/7B',
+    'llama-13b': './llama/hf/13B',
+    'llama-30b': './llama/hf/30B',
+    # 'llama-65b': './llama/hf/65B',
+    'alpaca': './alpaca-7B',
     # 'koala-7b': 'koala-7b',
     # 'koala-13b': 'koala-13b',
 }
 
-for m in model_mapping.keys():
-    for eid, ent in enumerate(ents):
-        os.makedirs(f'result/openai_result/{m}/ptb/per_ent/{ent}', exist_ok=True)
-    os.makedirs(f'result/structured_prompt/{m}/ptb', exist_ok=True)
+# for m in model_mapping.keys():
+#     for eid, ent in enumerate(ents):
+#         os.makedirs(f'result/prompt1_qa/{m}/ptb/per_ent/{ent}', exist_ok=True)
+
+#     os.makedirs(f'result/prompt2_instruction/pos_tagging/{m}/ptb', exist_ok=True)
+#     os.makedirs(f'result/prompt2_instruction/chunking/{m}/ptb', exist_ok=True)
+#     os.makedirs(f'result/prompt2_instruction/parsing/{m}/ptb', exist_ok=True)
+
+#     os.makedirs(f'result/prompt3_structured_prompt/pos_tagging/{m}/ptb', exist_ok=True)
+#     os.makedirs(f'result/prompt3_structured_prompt/chunking/{m}/ptb', exist_ok=True)
+#     os.makedirs(f'result/prompt3_structured_prompt/parsing/{m}/ptb', exist_ok=True)
 
 
 # s = int(sys.argv[1])
 # e = int(sys.argv[2])
 
-s = 0
-e = 1000
+# s = 0
+# e = 1000
 with open('sample_uniform_1k_2.txt', 'r') as f:
     selected_idx = f.readlines()
-selected_idx = [int(i.strip()) for i in selected_idx][s:e]
+selected_idx = [int(i.strip()) for i in selected_idx]#[s:e]
 
 
 ptb = []
 with open('ptb.jsonl', 'r') as f:
     for l in f:
         ptb.append(json.loads(l))
 
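Note: `ptb` is built from ptb.jsonl and later indexed as `ptb[gid]['text']`, `ptb[gid]['tokens']`, and `ptb[gid]['uni_poss']` (prompt 3), so each line is presumably a JSON object of roughly this shape (illustrative values, not taken from the dataset):

    import json

    # hypothetical one-line record, matching the fields this script reads:
    record = json.loads(
        '{"text": "The cat sat .",'
        ' "tokens": ["The", "cat", "sat", "."],'
        ' "uni_poss": ["DET", "NOUN", "VERB", "PUNC"]}'
    )
    assert set(record) >= {'text', 'tokens', 'uni_poss'}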
@@ -90,8 +103,19 @@ template_all = '''Please output the <Noun, Verb, Adjective, Adverb, Preposition/
 template_single = '''Please output any <{}> in the following sentence one per line without any additional text: "{}"'''
 
 ## Prompt 2
+prompt2_pos = '''Please pos tag the following sentence using Universal POS tag set without generating any additional text: {}'''
+prompt2_chunk = '''Please do sentence chunking for the following sentence as in CoNLL 2000 shared task without generating any additional text: {}'''
+prompt2_parse = '''Generate textual representation of the constituency parse tree of the following sentence using Penn TreeBank tag set without outputting any additional text: {}'''
+
+prompt2_chunk = '''Please chunk the following sentence in CoNLL 2000 format with BIO tags without outputting any additional text: {}'''
+
+## Prompt 3
+with open('demonstration_3_42_pos.txt', 'r') as f:
+    demon_pos = f.read()
 with open('demonstration_3_42_chunk.txt', 'r') as f:
-    demonstration = f.read()
+    demon_chunk = f.read()
+with open('demonstration_3_42_parse.txt', 'r') as f:
+    demon_parse = f.read()
 
 
 def para(m):
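Note: rendered, `template_single` produces one query per tag type; for example (sentence invented for illustration):

    msg = template_single.format('Noun Phrase', 'The quick brown fox jumps .')
    # Please output any <Noun Phrase> in the following sentence one per line
    # without any additional text: "The quick brown fox jumps ."

The demonstration_3_42_* files feed prompt 3; judging from how they are concatenated below, each demonstration presumably ends with a context line `C: <sentence>` and a target line `T:` that the model completes token by token.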
@@ -102,11 +126,14 @@ def para(m):
 
 def main(args=None):
 
-    if 'gpt3' in args.model:
+    gid_list = selected_idx[args.start:args.end]
+
+
+    if 'gpt3' in args.model_path:
         pass
 
     else:
-        path = model_mapping[args.model]
+        path = model_mapping[args.model_path]
         model, tokenizer = load_model(
             path,
             args.device,
@@ -118,64 +145,267 @@ def main(args=None):
             debug=args.debug,
         )
 
+        whitelist_ids_pos = [tokenizer.encode(word)[1] for word in uni_tags]
+        bad_words_ids_pos = [[ids] for ids in range(tokenizer.vocab_size) if ids not in whitelist_ids_pos]
+
+        whitelist_ids_bio = [tokenizer.encode(word)[1] for word in bio_tags]
+        bad_words_ids_bio = [[ids] for ids in range(tokenizer.vocab_size) if ids not in whitelist_ids_bio]
+
+        whitelist_ids_chunk = [tokenizer.encode(word)[1] for word in chunk_tags]
+        bad_words_ids_chunk = [[ids] for ids in range(tokenizer.vocab_size) if ids not in whitelist_ids_chunk]
+
+        whitelist_ids_parse = [tokenizer.encode(word)[1] for word in syntags]
+        bad_words_ids_parse = [[ids] for ids in range(tokenizer.vocab_size) if ids not in whitelist_ids_parse]
+
+
     if args.prompt == 1:
-        for gid in tqdm(selected_idx, desc='Query'):
+        for gid in tqdm(gid_list, desc='Query'):
             text = ptb[gid]['text']
 
             for eid, ent in enumerate(ents):
-                # if os.path.exists(f'result/openai_result/{args.model}/ptb/per_ent/{ent}/{gid}.pkl') or \
-                #    os.path.exists(f'result/openai_result/{args.model}/ptb/per_ent/{ent}/{gid}.txt'):
-                #     print(gid, ent, 'skip')
-                #     continue
+                os.makedirs(f'result/prompt1_qa/{args.model_path}/ptb/per_ent/{ent}', exist_ok=True)
+
+                if ent == 'NOUN' and not os.path.exists(f'result/prompt1_qa/{args.model_path}/ptb/per_ent/NOUN'):
+                    os.system(f'ln -sT ./NN result/prompt1_qa/{args.model_path}/ptb/per_ent/NOUN')
+                if ent == 'VERB' and not os.path.exists(f'result/prompt1_qa/{args.model_path}/ptb/per_ent/VERB'):
+                    os.system(f'ln -sT ./VB result/prompt1_qa/{args.model_path}/ptb/per_ent/VERB')
+                if ent == 'ADJ' and not os.path.exists(f'result/prompt1_qa/{args.model_path}/ptb/per_ent/ADJ'):
+                    os.system(f'ln -sT ./JJ result/prompt1_qa/{args.model_path}/ptb/per_ent/ADJ')
+                if ent == 'ADV' and not os.path.exists(f'result/prompt1_qa/{args.model_path}/ptb/per_ent/ADV'):
+                    os.system(f'ln -sT ./RB result/prompt1_qa/{args.model_path}/ptb/per_ent/ADV')
+                if ent == 'CONJ' and not os.path.exists(f'result/prompt1_qa/{args.model_path}/ptb/per_ent/CONJ'):
+                    os.system(f'ln -sT ./CC result/prompt1_qa/{args.model_path}/ptb/per_ent/CONJ')
+                if ent == 'DET' and not os.path.exists(f'result/prompt1_qa/{args.model_path}/ptb/per_ent/DET'):
+                    os.system(f'ln -sT ./DT result/prompt1_qa/{args.model_path}/ptb/per_ent/DET')
+                if ent == 'ADP' and not os.path.exists(f'result/prompt1_qa/{args.model_path}/ptb/per_ent/ADP'):
+                    os.system(f'ln -sT ./IN result/prompt1_qa/{args.model_path}/ptb/per_ent/ADP')
+
+                if os.path.exists(f'result/prompt1_qa/{args.model_path}/ptb/per_ent/{ent}/{gid}.txt'):
+                    print(gid, ent, 'skip')
+                    continue
+
 
                 ## Get prompt
                 msg = template_single.format(ents_prompt[eid], text)
 
-                if 'gpt' in args.model:
-                    prompt = msg
-
-                elif 'vicuna' in args.model or 'alpaca' in args.model or 'fastchat-t5' in args.model:
-                    conv = get_conversation_template(args.model)
+                ## Run
+                if 'gpt3' in args.model_path:
+                    if os.path.exists(f'result/prompt1_qa/{args.model_path}/ptb/per_ent/{ent}/{gid}.pkl'):
+                        print('Found cache')
+                        with open(f'result/prompt1_qa/{args.model_path}/ptb/per_ent/{ent}/{gid}.pkl', 'rb') as f:
+                            outputs = pickle.load(f)
+                        outputs = outputs['choices'][0]['message']['content']
+                    else:
+                        outputs = gpt3(msg)
+                        if outputs is None:
+                            continue
+                        time.sleep(0.2)
+
+                else:
+                    conv = get_conversation_template(args.model_path)
                     conv.append_message(conv.roles[0], msg)
                     conv.append_message(conv.roles[1], None)
                     conv.system = ''
                     prompt = conv.get_prompt().strip()
+                    outputs = fastchat(prompt, model, tokenizer)
 
-                elif 'llama-' in args.model:
-                    prompt = '### Human: ' + msg + ' ### Assistant:'
-
-                ## Run
-                if 'gpt3' in args.model:
-                    outputs = gpt3(prompt)
-
-                else:
-                    outputs = fastchat(prompt, model, tokenizer)
-
-                with open(f'result/openai_result/{args.model}/ptb/per_ent/{ent}/{gid}.txt', 'w') as f:
+                with open(f'result/prompt1_qa/{args.model_path}/ptb/per_ent/{ent}/{gid}.txt', 'w') as f:
                     f.write(outputs)
 
 
-    if args.prompt == 2:
-        for gid in tqdm(selected_idx, desc='Query'):
-            text = ptb[gid]['text']
-
-            if os.path.exists(f'result/structured_prompt/{args.model}/ptb/{gid}.pkl') or \
-               os.path.exists(f'result/structured_prompt/{args.model}/ptb/{gid}.txt'):
-                print(gid, 'skip')
-                continue
-
-            prompt = demonstration + '\n' + text
-
-            if 'gpt3' in args.model:
-                outputs = gpt3(prompt)
-
-            else:
-                outputs = fastchat(prompt, model, tokenizer)
-
-            with open(f'result/structured_prompt/{args.model}/ptb/{gid}.txt', 'w') as f:
-                f.write(outputs)
+    if args.prompt == 2:
+        for gid in tqdm(gid_list, desc='Query'):
+            text = ptb[gid]['text']
+
+            ## POS tagging
+            # if os.path.exists(f'result/prompt2_instruction/pos_tagging/{args.model_path}/ptb/{gid}.txt'):
+            #     print(gid, 'skip')
+
+            # else:
+            #     msg = prompt2_pos.format(text)
+
+            #     if 'gpt3' in args.model_path:
+            #         outputs = gpt3(msg)
+            #         if outputs is None:
+            #             continue
+            #         time.sleep(0.2)
+
+            #     else:
+            #         conv = get_conversation_template(args.model_path)
+            #         conv.append_message(conv.roles[0], msg)
+            #         conv.append_message(conv.roles[1], None)
+            #         conv.system = ''
+            #         prompt = conv.get_prompt()
+
+            #         outputs = fastchat(prompt, model, tokenizer)
+
+            #     with open(f'result/prompt2_instruction/pos_tagging/{args.model_path}/ptb/{gid}.txt', 'w') as f:
+            #         f.write(outputs)
+
+
+            ## Sentence chunking
+            # if os.path.exists(f'result/prompt2_instruction/chunking/{args.model_path}/ptb/{gid}.txt'):
+            #     print(gid, 'skip')
+            if False:
+                pass
+            else:
+                msg = prompt2_chunk.format(text)
+
+                if 'gpt3' in args.model_path:
+                    outputs = gpt3(msg)
+                    if outputs is None:
+                        continue
+                    time.sleep(0.2)
+
+                else:
+                    conv = get_conversation_template(args.model_path)
+                    conv.append_message(conv.roles[0], msg)
+                    conv.append_message(conv.roles[1], None)
+                    conv.system = ''
+                    prompt = conv.get_prompt()
+
+                    outputs = fastchat(prompt, model, tokenizer)
+
+                print(args.model_path, gid, outputs)
+                with open(f'result/prompt2_instruction/chunking/{args.model_path}/ptb/{gid}.txt', 'w') as f:
+                    f.write(outputs)
+
+
+            ## Parsing
+            # if os.path.exists(f'result/prompt2_instruction/parsing/{args.model_path}/ptb/{gid}.txt'):
+            #     print(gid, 'skip')
+
+            # else:
+            #     msg = prompt2_parse.format(text)
+
+            #     if 'gpt3' in args.model_path:
+            #         outputs = gpt3(msg)
+            #         if outputs is None:
+            #             continue
+            #         time.sleep(0.2)
+
+            #     else:
+            #         conv = get_conversation_template(args.model_path)
+            #         conv.append_message(conv.roles[0], msg)
+            #         conv.append_message(conv.roles[1], None)
+            #         conv.system = ''
+            #         prompt = conv.get_prompt()
+
+            #         outputs = fastchat(prompt, model, tokenizer)
+
+            #     with open(f'result/prompt2_instruction/parsing/{args.model_path}/ptb/{gid}.txt', 'w') as f:
+            #         f.write(outputs)
+
+
+
+    if args.prompt == 3:
+        for gid in tqdm(gid_list, desc='Query'):
+            text = ptb[gid]['text']
+            tokens = ptb[gid]['tokens']
+            poss = ptb[gid]['uni_poss']
+
+            ## POS tagging
+            # if os.path.exists(f'result/prompt3_structured_prompt/pos_tagging/{args.model_path}/ptb/{gid}.txt'):
+            #     print(gid, 'skip')
+            #     continue
+
+            # prompt = demon_pos + '\n' + 'C: ' + text + '\n' + 'T: '
+
+            # if 'gpt3' in args.model_path:
+            #     outputs = gpt3(prompt)
+            #     if outputs is None:
+            #         continue
+            #     time.sleep(0.2)
+
+            # else:
+            #     pred_poss = []
+            #     for _tok, _pos in zip(tokens, poss):
+            #         prompt = prompt + ' ' + _tok + '_'
+            #         outputs = structured_prompt(prompt, model, tokenizer, bad_words_ids_pos)
+            #         prompt = prompt + outputs
+            #         pred_poss.append(outputs)
+
+            #     outputs = ' '.join(pred_poss)
+            #     with open(f'result/prompt3_structured_prompt/pos_tagging/{args.model_path}/ptb/{gid}.txt', 'w') as f:
+            #         f.write(outputs)
+
+
+            ## Chunking
+            if os.path.exists(f'result/prompt3_structured_prompt/chunking/{args.model_path}/ptb/{gid}.txt'):
+                print(gid, 'skip')
+                continue
+
+            prompt = demon_chunk + '\n' + 'C: ' + text + '\n' + 'T: '
+
+            if 'gpt3' in args.model_path:
+                outputs = gpt3(prompt)
+                print(outputs)
+                if outputs is None:
+                    continue
+                time.sleep(0.2)
+
+            else:
+                pred_chunk = []
+                for _tok, _pos in zip(tokens, poss):
+                    prompt = prompt + ' ' + _tok + '_'
+
+                    # Generate BIO
+                    outputs_bio = structured_prompt(prompt, model, tokenizer, bad_words_ids_bio)
+                    prompt = prompt + outputs_bio + '-'
+
+                    # Generate tag
+                    outputs_chunk = structured_prompt(prompt, model, tokenizer, bad_words_ids_chunk)
+                    prompt = prompt + outputs_chunk
+
+                    pred_chunk.append((outputs_bio + '-' + outputs_chunk))
+
+                outputs = ' '.join(pred_chunk)
+
+            with open(f'result/prompt3_structured_prompt/chunking/{args.model_path}/ptb/{gid}.txt', 'w') as f:
+                f.write(outputs)
+
+            ## Parsing
+            # if os.path.exists(f'result/prompt3_structured_prompt/parsing/{args.model_path}/ptb/{gid}.txt'):
+            #     print(gid, 'skip')
+            #     continue
+
+            # prompt = demon_parse + '\n' + 'C: ' + text + '\n' + 'T: '
+
+            # if 'gpt3' in args.model_path:
+            #     outputs = gpt3(prompt)
+            #     if outputs is None:
+            #         continue
+            #     time.sleep(0.2)
+
+            # else:
+            #     pred_syn = []
+            #     for _tok, _pos in zip(tokens, poss):
+            #         prompt = prompt + _tok + '_'
+            #         outputs = structured_prompt(prompt, model, tokenizer, bad_words_ids_parse)
+            #         pred_syn.append(outputs)
+
+            #     with open(f'result/prompt3_structured_prompt/parsing/{args.model_path}/ptb/{gid}.txt', 'w') as f:
+            #         f.write(' '.join(pred_syn))
+
+
+def structured_prompt(prompt, model, tokenizer, bad_words_ids):
+    input_ids = tokenizer([prompt]).input_ids
+    output_ids = model.generate(
+        torch.as_tensor(input_ids).cuda(),
+        max_new_tokens=1,
+        bad_words_ids=bad_words_ids,
+    )
+
+    if model.config.is_encoder_decoder:
+        output_ids = output_ids[0]
+    else:
+        output_ids = output_ids[0][len(input_ids[0]) :]
+    outputs = tokenizer.decode(
+        output_ids, skip_special_tokens=True, spaces_between_special_tokens=False
+    )
+
+    return outputs
+
 
 def fastchat(prompt, model, tokenizer):
     input_ids = tokenizer([prompt]).input_ids
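Note: the `whitelist_ids_*`/`bad_words_ids_*` pairs implement single-token constrained decoding. Every vocabulary id except the encoded tag words is handed to `generate()` as banned, so each `structured_prompt()` call can only emit an allowed tag (`tokenizer.encode(word)[1]` assumes a SentencePiece-style tokenizer that puts a BOS token at index 0); prompt 3 then grows the prompt token by token in the `C: <sentence>` / `T: token_TAG ...` demonstration format. An equivalent formulation, sketched here as an alternative rather than taken from the commit, passes the whitelist directly and avoids materializing tens of thousands of single-id ban lists:

    import torch

    def allow_only(allowed_ids):
        # transformers calls this at every decoding step; returning the same
        # whitelist each time restricts sampling to exactly these token ids
        def fn(batch_id, input_ids):
            return list(allowed_ids)
        return fn

    # drop-in for one structured_prompt() step, same inputs as in the diff:
    input_ids = tokenizer([prompt]).input_ids
    output_ids = model.generate(
        torch.as_tensor(input_ids).cuda(),
        max_new_tokens=1,
        prefix_allowed_tokens_fn=allow_only(whitelist_ids_bio),
    )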
@@ -205,16 +435,15 @@ def fastchat(prompt, model, tokenizer):
 def gpt3(prompt):
     try:
         response = openai.ChatCompletion.create(
-            model=args.model, messages=[{"role": "user", "content": prompt}])
+            model=model_mapping[args.model_path], messages=[{"role": "user", "content": prompt}])
 
-        return response
+        return response['choices'][0]['message']['content']
 
     except Exception as err:
         print('Error')
         print(err)
 
-        # time.sleep(1)
-        raise
+        return None
 
@@ -226,13 +455,13 @@ if __name__ == "__main__":
     parser.add_argument("--debug", action="store_true")
     parser.add_argument("--message", type=str, default="Hello! Who are you?")
     parser.add_argument("--start", type=int, default=0)
-    parser.add_argument("--end", type=int, default=1)
-    parser.add_argument("--model", required=True, type=str, default=None)
+    parser.add_argument("--end", type=int, default=1000)
     parser.add_argument("--prompt", required=True, type=int, default=None)
+    # parser.add_argument("--system_msg", required=True, type=str, default='default_system_msg')
     args = parser.parse_args()
 
     # Reset default repetition penalty for T5 models.
-    if "t5" in args.model and args.repetition_penalty == 1.0:
+    if "t5" in args.model_path and args.repetition_penalty == 1.0:
         args.repetition_penalty = 1.2
 
     main(args)
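Note: with the old `--model` flag deleted, the model is now selected through FastChat's `--model-path` argument (registered by `add_model_args`), and `--start`/`--end` shard the 1k-sentence sample in place of the previously hard-coded `s`/`e` bounds. A typical invocation would look something like `python run_llm.py --model-path vicuna-13b --prompt 2 --start 0 --end 1000` (illustrative; the key must also appear in `model_mapping`).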
 