research14 committed
Commit 6525dcf
1 Parent(s): bfc086f

Edits applied

Files changed (2)
  1. app.py +2 -46
  2. run_llm.py +370 -73
app.py CHANGED
@@ -54,7 +54,7 @@ model_mapping = {
     #'vicuna-7b': 'lmsys/vicuna-7b-v1.3',
     #'vicuna-13b': 'lmsys/vicuna-13b-v1.3',
     #'vicuna-33b': 'lmsys/vicuna-33b-v1.3',
-    'fastchat-t5': 'lmsys/fastchat-t5-3b-v1.0',
+    #'fastchat-t5': 'lmsys/fastchat-t5-3b-v1.0',
     #'llama-7b': './llama/hf/7B',
     #'llama-13b': './llama/hf/13B',
     #'llama-30b': './llama/hf/30B',
@@ -98,7 +98,7 @@ gpt_pipeline = pipeline(task="text-generation", model="gpt2")
 #vicuna7b_pipeline = pipeline(task="text2text-generation", model="lmsys/vicuna-7b-v1.3")
 #vicuna13b_pipeline = pipeline(task="text2text-generation", model="lmsys/vicuna-13b-v1.3")
 #vicuna33b_pipeline = pipeline(task="text2text-generation", model="lmsys/vicuna-33b-v1.3")
-fastchatT5_pipeline = pipeline(task="text2text-generation", model="lmsys/fastchat-t5-3b-v1.0")
+#fastchatT5_pipeline = pipeline(task="text2text-generation", model="lmsys/fastchat-t5-3b-v1.0")
 #llama7b_pipeline = pipeline(task="text2text-generation", model="./llama/hf/7B")
 #llama13b_pipeline = pipeline(task="text2text-generation", model="./llama/hf/13B")
 #llama30b_pipeline = pipeline(task="text2text-generation", model="./llama/hf/30B")
@@ -145,50 +145,6 @@ def process_text(model_name, task, text):
         response2 = gpt_pipeline(strategy2)[0]['generated_text']
         response3 = gpt_pipeline(strategy3)[0]['generated_text']
         return (response1, response2, response3)
-    elif model_name == 'fastchat-t5':
-        if task == 'POS':
-            strategy1 = template_all.format(text)
-            strategy2 = prompt2_pos.format(text)
-            strategy3 = demon_pos
-
-            response1 = fastchatT5_pipeline(strategy1)[0]['generated_text']
-            response2 = fastchatT5_pipeline(strategy2)[0]['generated_text']
-            response3 = fastchatT5_pipeline(strategy3)[0]['generated_text']
-            return (response1, response2, response3)
-        elif task == 'Chunking':
-            strategy1 = template_all.format(text)
-            strategy2 = prompt2_chunk.format(text)
-            strategy3 = demon_chunk
-
-            response1 = fastchatT5_pipeline(strategy1)[0]['generated_text']
-            response2 = fastchatT5_pipeline(strategy2)[0]['generated_text']
-            response3 = fastchatT5_pipeline(strategy3)[0]['generated_text']
-            return (response1, response2, response3)
-        elif task == 'Parsing':
-            strategy1 = template_all.format(text)
-            strategy2 = prompt2_parse.format(text)
-            strategy3 = demon_parse
-
-            response1 = fastchatT5_pipeline(strategy1)[0]['generated_text']
-            response2 = fastchatT5_pipeline(strategy2)[0]['generated_text']
-            response3 = fastchatT5_pipeline(strategy3)[0]['generated_text']
-            return (response1, response2, response3)
-
-    # Define prompts for each strategy based on the task
-    #strategy_prompts = {
-    #    'Strategy 1': template_all.format(text),
-    #    'Strategy 2': {
-    #        'POS': prompt2_pos.format(text),
-    #        'Chunking': prompt2_chunk.format(text),
-    #        'Parsing': prompt2_parse.format(text),
-    #    }.get(task, "Invalid Task Selection for Strategy 2"),
-    #    'Strategy 3': {
-    #        'POS': demon_pos,
-    #        'Chunking': demon_chunk,
-    #        'Parsing': demon_parse,
-    #    }.get(task, "Invalid Task Selection for Strategy 3"),
-    #}
-
 
 # Gradio interface
 iface = gr.Interface(
run_llm.py CHANGED
@@ -15,7 +15,6 @@ from fastchat.model import load_model, get_conversation_template, add_model_args
 from nltk.tag.mapping import _UNIVERSAL_TAGS
 
 import gradio as gr
-from transformers import pipeline
 
 uni_tags = list(_UNIVERSAL_TAGS)
 uni_tags[-1] = 'PUNC'
@@ -29,6 +28,7 @@ syntags = ['NP', 'S', 'VP', 'ADJP', 'ADVP', 'SBAR', 'TOP', 'PP', 'POS', 'NAC', "
 
 openai.api_key = "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"
 
+
 # determinant vs. determiner
 # https://wikidiff.com/determiner/determinant
 ents_prompt = ['Noun','Verb','Adjective','Adverb','Preposition/Subord','Coordinating Conjunction',# 'Cardinal Number',
@@ -48,23 +48,51 @@ ents_prompt = ents_prompt_uni_tags + ents_prompt
 
 for i, j in zip(ents, ents_prompt):
     print(i, j)
+# raise
+
 
 model_mapping = {
+    # 'gpt3': 'gpt-3',
     'gpt3.5': 'gpt-3.5-turbo-0613',
     'vicuna-7b': 'lmsys/vicuna-7b-v1.3',
     'vicuna-13b': 'lmsys/vicuna-13b-v1.3',
     'vicuna-33b': 'lmsys/vicuna-33b-v1.3',
     'fastchat-t5': 'lmsys/fastchat-t5-3b-v1.0',
+    # 'llama2-7b': 'meta-llama/Llama-2-7b-hf',
+    # 'llama2-13b': 'meta-llama/Llama-2-13b-hf',
+    # 'llama2-70b': 'meta-llama/Llama-2-70b-hf',
     'llama-7b': './llama/hf/7B',
     'llama-13b': './llama/hf/13B',
     'llama-30b': './llama/hf/30B',
+    # 'llama-65b': './llama/hf/65B',
     'alpaca': './alpaca-7B',
+    # 'koala-7b': 'koala-7b',
+    # 'koala-13b': 'koala-13b',
 }
 
+for m in model_mapping.keys():
+    for eid, ent in enumerate(ents):
+        os.makedirs(f'result/prompt1_qa/{m}/ptb/per_ent/{ent}', exist_ok=True)
+
+    os.makedirs(f'result/prompt2_instruction/pos_tagging/{m}/ptb', exist_ok=True)
+    os.makedirs(f'result/prompt2_instruction/chunking/{m}/ptb', exist_ok=True)
+    os.makedirs(f'result/prompt2_instruction/parsing/{m}/ptb', exist_ok=True)
+
+    os.makedirs(f'result/prompt3_structured_prompt/pos_tagging/{m}/ptb', exist_ok=True)
+    os.makedirs(f'result/prompt3_structured_prompt/chunking/{m}/ptb', exist_ok=True)
+    os.makedirs(f'result/prompt3_structured_prompt/parsing/{m}/ptb', exist_ok=True)
+
+
+#s = int(sys.argv[1])
+#e = int(sys.argv[2])
+
+#s = 0
+#e = 1000
 with open('sample_uniform_1k_2.txt', 'r') as f:
     selected_idx = f.readlines()
 selected_idx = [int(i.strip()) for i in selected_idx]#[s:e]
 
+
 ptb = []
 with open('ptb.jsonl', 'r') as f:
     for l in f:
@@ -90,82 +118,351 @@ with open('demonstration_3_42_chunk.txt', 'r') as f:
 with open('demonstration_3_42_parse.txt', 'r') as f:
     demon_parse = f.read()
 
-# Your existing code
-theme = gr.themes.Soft()
 
-pipeline = pipeline(task="text-generation", model="lmsys/vicuna-7b-v1.3")
+def para(m):
+    c = 0
+    for n, p in m.named_parameters():
+        c += p.numel()
+    return c
+
+def main(args=None):
+
+    gid_list = selected_idx[args.start:args.end]
+
+
+    if 'gpt3' in args.model_path:
+        pass
+
+    else:
+        path = model_mapping[args.model_path]
+        model, tokenizer = load_model(
+            path,
+            args.device,
+            args.num_gpus,
+            args.max_gpu_memory,
+            args.load_8bit,
+            args.cpu_offloading,
+            revision=args.revision,
+            debug=args.debug,
+        )
+
+        whitelist_ids_pos = [tokenizer.encode(word)[1] for word in uni_tags]
+        bad_words_ids_pos = [[ids] for ids in range(tokenizer.vocab_size) if ids not in whitelist_ids_pos]
+
+        whitelist_ids_bio = [tokenizer.encode(word)[1] for word in bio_tags]
+        bad_words_ids_bio = [[ids] for ids in range(tokenizer.vocab_size) if ids not in whitelist_ids_bio]
+
+        whitelist_ids_chunk = [tokenizer.encode(word)[1] for word in chunk_tags]
+        bad_words_ids_chunk = [[ids] for ids in range(tokenizer.vocab_size) if ids not in whitelist_ids_chunk]
+
+        whitelist_ids_parse = [tokenizer.encode(word)[1] for word in syntags]
+        bad_words_ids_parse = [[ids] for ids in range(tokenizer.vocab_size) if ids not in whitelist_ids_parse]
+
+
+    if args.prompt == 1:
+        for gid in tqdm(gid_list, desc='Query'):
+            text = ptb[gid]['text']
+
+            for eid, ent in enumerate(ents):
+                os.makedirs(f'result/prompt1_qa/{args.model_path}/ptb/per_ent/{ent}', exist_ok=True)
+
+                if ent == 'NOUN' and not os.path.exists(f'result/prompt1_qa/{args.model_path}/ptb/per_ent/NOUN'):
+                    os.system(f'ln -sT ./NN result/prompt1_qa/{args.model_path}/ptb/per_ent/NOUN')
+                if ent == 'VERB' and not os.path.exists(f'result/prompt1_qa/{args.model_path}/ptb/per_ent/VERB'):
+                    os.system(f'ln -sT ./VB result/prompt1_qa/{args.model_path}/ptb/per_ent/VERB')
+                if ent == 'ADJ' and not os.path.exists(f'result/prompt1_qa/{args.model_path}/ptb/per_ent/ADJ'):
+                    os.system(f'ln -sT ./JJ result/prompt1_qa/{args.model_path}/ptb/per_ent/ADJ')
+                if ent == 'ADV' and not os.path.exists(f'result/prompt1_qa/{args.model_path}/ptb/per_ent/ADV'):
+                    os.system(f'ln -sT ./RB result/prompt1_qa/{args.model_path}/ptb/per_ent/ADV')
+                if ent == 'CONJ' and not os.path.exists(f'result/prompt1_qa/{args.model_path}/ptb/per_ent/CONJ'):
+                    os.system(f'ln -sT ./CC result/prompt1_qa/{args.model_path}/ptb/per_ent/CONJ')
+                if ent == 'DET' and not os.path.exists(f'result/prompt1_qa/{args.model_path}/ptb/per_ent/DET'):
+                    os.system(f'ln -sT ./DT result/prompt1_qa/{args.model_path}/ptb/per_ent/DET')
+                if ent == 'ADP' and not os.path.exists(f'result/prompt1_qa/{args.model_path}/ptb/per_ent/ADP'):
+                    os.system(f'ln -sT ./DT result/prompt1_qa/{args.model_path}/ptb/per_ent/IN')
+
+                if os.path.exists(f'result/prompt1_qa/{args.model_path}/ptb/per_ent/{ent}/{gid}.txt'):
+                    print(gid, ent, 'skip')
+                    continue
+
+
+                ## Get prompt
+                msg = template_single.format(ents_prompt[eid], text)
+
+                ## Run
+                if 'gpt3' in args.model_path:
+                    if os.path.exists(f'result/prompt1_qa/{args.model_path}/ptb/per_ent/{ent}/{gid}.pkl'):
+                        print('Found cache')
+                        with open(f'result/prompt1_qa/{args.model_path}/ptb/per_ent/{ent}/{gid}.pkl', 'rb') as f:
+                            outputs = pickle.load(f)
+                            outputs = outputs['choices'][0]['message']['content']
+                    else:
+                        outputs = gpt3(msg)
+                        if outputs is None:
+                            continue
+                        time.sleep(0.2)
+
+                else:
+                    conv = get_conversation_template(args.model_path)
+                    conv.append_message(conv.roles[0], msg)
+                    conv.append_message(conv.roles[1], None)
+                    conv.system = ''
+                    prompt = conv.get_prompt().strip()
+                    outputs = fastchat(prompt, model, tokenizer)
+
+                with open(f'result/prompt1_qa/{args.model_path}/ptb/per_ent/{ent}/{gid}.txt', 'w') as f:
+                    f.write(outputs)
+
+
+    if args.prompt == 2:
+        for gid in tqdm(gid_list, desc='Query'):
+            text = ptb[gid]['text']
+
+            ## POS tagging
+            if os.path.exists(f'result/prompt2_instruction/pos_tagging/{args.model_path}/ptb/{gid}.txt'):
+                print(gid, 'skip')
+
+            else:
+                msg = prompt2_pos.format(text)
+
+                if 'gpt3' in args.model_path:
+                    outputs = gpt3(msg)
+                    if outputs is None:
+                        continue
+                    time.sleep(0.2)
+
+                else:
+                    conv = get_conversation_template(args.model_path)
+                    conv.append_message(conv.roles[0], msg)
+                    conv.append_message(conv.roles[1], None)
+                    conv.system = ''
+                    prompt = conv.get_prompt()
+
+                    outputs = fastchat(prompt, model, tokenizer)
+
+                with open(f'result/prompt2_instruction/pos_tagging/{args.model_path}/ptb/{gid}.txt', 'w') as f:
+                    f.write(outputs)
+
+
+            ## Sentence chunking
+            if os.path.exists(f'result/prompt2_instruction/chunking/{args.model_path}/ptb/{gid}.txt'):
+                print(gid, 'skip')
+            if False:
+                pass
+            else:
+                msg = prompt2_chunk.format(text)
+
+                if 'gpt3' in args.model_path:
+                    outputs = gpt3(msg)
+                    if outputs is None:
+                        continue
+                    time.sleep(0.2)
+
+                else:
+                    conv = get_conversation_template(args.model_path)
+                    conv.append_message(conv.roles[0], msg)
+                    conv.append_message(conv.roles[1], None)
+                    conv.system = ''
+                    prompt = conv.get_prompt()
+
+                    outputs = fastchat(prompt, model, tokenizer)
+
+                print(args.model_path, gid, outputs)
+                with open(f'result/prompt2_instruction/chunking/{args.model_path}/ptb/{gid}.txt', 'w') as f:
+                    f.write(outputs)
+
+
+            ## Parsing
+            if os.path.exists(f'result/prompt2_instruction/parsing/{args.model_path}/ptb/{gid}.txt'):
+                print(gid, 'skip')
+
+            else:
+                msg = prompt2_parse.format(text)
 
-# Dropdown options for model and task
-model_options = list(model_mapping.keys())
-task_options = ['POS', 'Chunking', 'Parsing']
+                if 'gpt3' in args.model_path:
+                    outputs = gpt3(msg)
+                    if outputs is None:
+                        continue
+                    time.sleep(0.2)
 
-# Function to process text based on model and task
-def process_text(model_name, task, text):
-    gid_list = selected_idx[0:20]
+                else:
+                    conv = get_conversation_template(args.model_path)
+                    conv.append_message(conv.roles[0], msg)
+                    conv.append_message(conv.roles[1], None)
+                    conv.system = ''
+                    prompt = conv.get_prompt()
 
-    for gid in tqdm(gid_list, desc='Query'):
+                    outputs = fastchat(prompt, model, tokenizer)
+
+                with open(f'result/prompt2_instruction/parsing/{args.model_path}/ptb/{gid}.txt', 'w') as f:
+                    f.write(outputs)
+
+
+
+    if args.prompt == 3:
+        for gid in tqdm(gid_list, desc='Query'):
             text = ptb[gid]['text']
-
-        #if model_name is 'gpt3.5': 'gpt-3.5-turbo-0613',
-        #elif model_name is 'vicuna-7b': 'lmsys/vicuna-7b-v1.3',
-        #elif model_name is 'vicuna-13b': 'lmsys/vicuna-13b-v1.3',
-        #elif model_name is 'vicuna-33b': 'lmsys/vicuna-33b-v1.3',
-        #elif model_name is 'fastchat-t5': 'lmsys/fastchat-t5-3b-v1.0',
-        #elif model_name is 'llama-7b': './llama/hf/7B',
-        #elif model_name is 'llama-13b': './llama/hf/13B',
-        #elif model_name is 'llama-30b': './llama/hf/30B',
-        #elif model_name is 'alpaca': './alpaca-7B',
-
-        if task == 'POS':
-            strategy1 = pipeline(template_all.format(text))
-            strategy2 = pipeline(prompt2_pos.format(text))
-            strategy3 = pipeline(demon_pos)
-            return (strategy1, strategy2, strategy3)
-        elif task == 'Chunking':
-            strategy1 = pipeline(template_all.format(text))
-            strategy2 = pipeline(prompt2_chunk.format(text))
-            strategy3 = pipeline(demon_chunk)
-            return (strategy1, strategy2, strategy3)
-        elif task == 'Parsing':
-            strategy1 = pipeline(template_all.format(text))
-            strategy2 = pipeline(prompt2_parse.format(text))
-            strategy3 = pipeline(demon_parse)
-            return (strategy1, strategy2, strategy3)
-
-        # Define prompts for each strategy based on the task
-        #strategy_prompts = {
-        #    'Strategy 1': template_all.format(text),
-        #    'Strategy 2': {
-        #        'POS': prompt2_pos.format(text),
-        #        'Chunking': prompt2_chunk.format(text),
-        #        'Parsing': prompt2_parse.format(text),
-        #    }.get(task, "Invalid Task Selection for Strategy 2"),
-        #    'Strategy 3': {
-        #        'POS': demon_pos,
-        #        'Chunking': demon_chunk,
-        #        'Parsing': demon_parse,
-        #    }.get(task, "Invalid Task Selection for Strategy 3"),
-        #}
-
-
-# Gradio interface
-iface = gr.Interface(
-    fn=process_text,
-    inputs=[
-        gr.Dropdown(model_options, label="Select Model"),
-        gr.Dropdown(task_options, label="Select Task"),
-        gr.Textbox(label="Input Text", placeholder="Enter the text to process..."),
-    ],
-    outputs=[
-        gr.Textbox(label="Strategy 1 QA Result"),
-        gr.Textbox(label="Strategy 2 Instruction Result"),
-        gr.Textbox(label="Strategy 3 Structured Prompting Result"),
-    ],
-    title = "LLM Evaluator For Linguistic Scrutiny",
-    theme = theme,
-    live=False,
-)
-
-iface.launch()
+            tokens = ptb[gid]['tokens']
+            poss = ptb[gid]['uni_poss']
+
+            ## POS tagging
+            if os.path.exists(f'result/prompt3_structured_prompt/pos_tagging/{args.model_path}/ptb/{gid}.txt'):
+                print(gid, 'skip')
+                continue
+
+            prompt = demon_pos + '\n' + 'C: ' + text + '\n' + 'T: '
+
+            if 'gpt3' in args.model_path:
+                outputs = gpt3(prompt)
+                if outputs is None:
+                    continue
+                time.sleep(0.2)
+
+            else:
+                pred_poss = []
+                for _tok, _pos in zip(tokens, poss):
+                    prompt = prompt + ' ' + _tok + '_'
+                    outputs = structured_prompt(prompt, model, tokenizer, bad_words_ids_pos)
+                    prompt = prompt + outputs
+                    pred_poss.append(outputs)
+
+                outputs = ' '.join(pred_poss)
+            with open(f'result/prompt3_structured_prompt/pos_tagging/{args.model_path}/ptb/{gid}.txt', 'w') as f:
+                f.write(outputs)
+
+
+            ## Chunking
+            if os.path.exists(f'result/prompt3_structured_prompt/chunking/{args.model_path}/ptb/{gid}.txt'):
+                print(gid, 'skip')
+                continue
+
+            prompt = demon_chunk + '\n' + 'C: ' + text + '\n' + 'T: '
+
+            if 'gpt3' in args.model_path:
+                outputs = gpt3(prompt)
+                print(outputs)
+                if outputs is None:
+                    continue
+                time.sleep(0.2)
+
+            else:
+                pred_chunk = []
+                for _tok, _pos in zip(tokens, poss):
+                    prompt = prompt + ' ' + _tok + '_'
+
+                    # Generate BIO
+                    outputs_bio = structured_prompt(prompt, model, tokenizer, bad_words_ids_bio)
+                    prompt = prompt + outputs_bio + '-'
+
+                    # Generate tag
+                    outputs_chunk = structured_prompt(prompt, model, tokenizer, bad_words_ids_chunk)
+                    prompt = prompt + outputs_chunk
+
+                    pred_chunk.append((outputs_bio + '-' + outputs_chunk))
+
+                outputs = ' '.join(pred_chunk)
+
+            with open(f'result/prompt3_structured_prompt/chunking/{args.model_path}/ptb/{gid}.txt', 'w') as f:
+                f.write(outputs)
+
+            ## Parsing
+            if os.path.exists(f'result/prompt3_structured_prompt/parsing/{args.model_path}/ptb/{gid}.txt'):
+                print(gid, 'skip')
+                continue
+
+            prompt = demon_parse + '\n' + 'C: ' + text + '\n' + 'T: '
+
+            if 'gpt3' in args.model_path:
+                outputs = gpt3(prompt)
+                if outputs is None:
+                    continue
+                time.sleep(0.2)
+
+            else:
+                pred_syn = []
+                for _tok, _pos in zip(tokens, poss):
+                    prompt = prompt + _tok + '_'
+                    outputs = structured_prompt(prompt, model, tokenizer, bad_words_ids_parse)
+                    pred_syn.append(outputs)
+
+                with open(f'result/prompt3_structured_prompt/parsing/{args.model_path}/ptb/{gid}.txt', 'w') as f:
+                    f.write(' '.join(pred_syn))
+
+
+def structured_prompt(prompt, model, tokenizer, bad_words_ids):
+    input_ids = tokenizer([prompt]).input_ids
+    output_ids = model.generate(
+        torch.as_tensor(input_ids).cuda(),
+        max_new_tokens=1,
+        bad_words_ids=bad_words_ids,
+    )
+
+    if model.config.is_encoder_decoder:
+        output_ids = output_ids[0]
+    else:
+        output_ids = output_ids[0][len(input_ids[0]) :]
+    outputs = tokenizer.decode(
+        output_ids, skip_special_tokens=True, spaces_between_special_tokens=False
+    )
+
+    return outputs
+
+
+def fastchat(prompt, model, tokenizer):
+    input_ids = tokenizer([prompt]).input_ids
+    output_ids = model.generate(
+        torch.as_tensor(input_ids).cuda(),
+        do_sample=True,
+        temperature=args.temperature,
+        repetition_penalty=args.repetition_penalty,
+        max_new_tokens=args.max_new_tokens,
+    )
+
+    if model.config.is_encoder_decoder:
+        output_ids = output_ids[0]
+    else:
+        output_ids = output_ids[0][len(input_ids[0]) :]
+    outputs = tokenizer.decode(
+        output_ids, skip_special_tokens=True, spaces_between_special_tokens=False
+    )
+
+    #print('Empty system message')
+    #print(f"{conv.roles[0]}: {msg}")
+    #print(f"{conv.roles[1]}: {outputs}")
+
+    return outputs
+
+
+def gpt3(prompt):
+    try:
+        response = openai.ChatCompletion.create(
+            model=model_mapping[args.model_path], messages=[{"role": "user", "content": prompt}])
+
+        return response['choices'][0]['message']['content']
+
+    except Exception as err:
+        print('Error')
+        print(err)
+
+        return None
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    add_model_args(parser)
+    parser.add_argument("--temperature", type=float, default=0.7)
+    parser.add_argument("--repetition_penalty", type=float, default=1.0)
+    parser.add_argument("--max-new-tokens", type=int, default=512)
+    parser.add_argument("--debug", action="store_true")
+    parser.add_argument("--message", type=str, default="Hello! Who are you?")
+    parser.add_argument("--start", type=int, default=0)
+    parser.add_argument("--end", type=int, default=1000)
+    parser.add_argument("--prompt", required=True, type=int, default=None)
+    # parser.add_argument("--system_msg", required=True, type=str, default='default_system_msg')
+    args = parser.parse_args()
 
+    # Reset default repetition penalty for T5 models.
+    if "t5" in args.model_path and args.repetition_penalty == 1.0:
+        args.repetition_penalty = 1.2
 
+    main(args)
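
A plausible invocation of the new run_llm.py, shown only as a sketch: the --prompt, --start, --end, --temperature, --repetition_penalty and --max-new-tokens flags are defined in the argparse block above, while --model-path and the device/GPU flags are assumed to be supplied by FastChat's add_model_args. Model names are the keys of model_mapping (e.g. vicuna-7b, fastchat-t5, gpt3.5).

    python run_llm.py --model-path vicuna-7b --prompt 1 --start 0 --end 1000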