research14 commited on
Commit
816d981
·
1 Parent(s): 227a8a7

updated with vicuna

Browse files
Files changed (1) hide show
  1. app.py +26 -45
app.py CHANGED
@@ -50,8 +50,8 @@ for i, j in zip(ents, ents_prompt):
50
  print(i, j)
51
 
52
  model_mapping = {
53
- 'gpt3.5': 'gpt2',
54
- #'vicuna-7b': 'lmsys/vicuna-7b-v1.3',
55
  #'vicuna-13b': 'lmsys/vicuna-13b-v1.3',
56
  #'vicuna-33b': 'lmsys/vicuna-33b-v1.3',
57
  #'fastchat-t5': 'lmsys/fastchat-t5-3b-v1.0',
@@ -94,8 +94,8 @@ with open('demonstration_3_42_parse.txt', 'r') as f:
94
  theme = gr.themes.Soft()
95
 
96
 
97
- gpt_pipeline = pipeline(task="text2text-generation", model="gpt2")
98
- #vicuna7b_pipeline = pipeline(task="text2text-generation", model="lmsys/vicuna-7b-v1.3")
99
  #vicuna13b_pipeline = pipeline(task="text2text-generation", model="lmsys/vicuna-13b-v1.3")
100
  #vicuna33b_pipeline = pipeline(task="text2text-generation", model="lmsys/vicuna-33b-v1.3")
101
  #fastchatT5_pipeline = pipeline(task="text2text-generation", model="lmsys/fastchat-t5-3b-v1.0")
@@ -116,53 +116,34 @@ def process_text(model_name, task, text):
116
  for gid in tqdm(gid_list, desc='Query'):
117
  text = ptb[gid]['text']
118
 
119
- if model_name == 'gpt3.5':
120
  if task == 'POS':
121
- from transformers import pipeline
122
-
123
- # Load the pipeline for part-of-speech tagging using the vicuna7b model
124
- pos_pipeline = pipeline("ner", model="finiteautomata/bert2crf4ner-pos", tokenizer="finiteautomata/bert2crf4ner-pos")
125
-
126
- # Define the input phrase
127
- input_phrase = "A cat loves the beautiful dog in the neighbor's house"
128
-
129
- # Use the pipeline to get the part-of-speech tags
130
- pos_tags = pos_pipeline(input_phrase)
131
-
132
- # Print the output
133
- print(pos_tags)
134
-
135
  strategy1_format = template_all.format(text)
136
  strategy2_format = prompt2_pos.format(text)
137
  strategy3_format = demon_pos
138
 
139
- result1 = gpt_pipeline(strategy1_format)
140
- result2 = gpt_pipeline(strategy2_format)
141
- result3 = gpt_pipeline(strategy3_format)
142
-
143
- generated_text1 = result1[0]['generated_text']
144
- generated_text2 = result2[0]['generated_text']
145
- generated_text3 = result3[0]['generated_text']
146
-
147
- return (generated_text1, generated_text2, generated_text3)
148
- # elif task == 'Chunking':
149
- # strategy1_format = template_all.format(text)
150
- # strategy2_format = prompt2_chunk.format(text)
151
- # strategy3_format = demon_chunk
152
-
153
- # result1 = gpt_pipeline(strategy1_format)[0]['generated_text']
154
- # result2 = gpt_pipeline(strategy2_format)[0]['generated_text']
155
- # result3 = gpt_pipeline(strategy3_format)[0]['generated_text']
156
- # return (result1, result2, result3)
157
- # elif task == 'Parsing':
158
- # strategy1_format = template_all.format(text)
159
- # strategy2_format = prompt2_parse.format(text)
160
- # strategy3_format = demon_parse
161
 
162
- # result1 = gpt_pipeline(strategy1_format)[0]['generated_text']
163
- # result2 = gpt_pipeline(strategy2_format)[0]['generated_text']
164
- # result3 = gpt_pipeline(strategy3_format)[0]['generated_text']
165
- # return (result1, result2, result3)
166
 
167
  # Gradio interface
168
  iface = gr.Interface(
 
50
  print(i, j)
51
 
52
  model_mapping = {
53
+ #'gpt3.5': 'gpt2',
54
+ 'vicuna-7b': 'lmsys/vicuna-7b-v1.3',
55
  #'vicuna-13b': 'lmsys/vicuna-13b-v1.3',
56
  #'vicuna-33b': 'lmsys/vicuna-33b-v1.3',
57
  #'fastchat-t5': 'lmsys/fastchat-t5-3b-v1.0',
 
94
  theme = gr.themes.Soft()
95
 
96
 
97
+ #gpt_pipeline = pipeline(task="text2text-generation", model="gpt2")
98
+ vicuna7b_pipeline = pipeline(task="text2text-generation", model="lmsys/vicuna-7b-v1.3")
99
  #vicuna13b_pipeline = pipeline(task="text2text-generation", model="lmsys/vicuna-13b-v1.3")
100
  #vicuna33b_pipeline = pipeline(task="text2text-generation", model="lmsys/vicuna-33b-v1.3")
101
  #fastchatT5_pipeline = pipeline(task="text2text-generation", model="lmsys/fastchat-t5-3b-v1.0")
 
116
  for gid in tqdm(gid_list, desc='Query'):
117
  text = ptb[gid]['text']
118
 
119
+ if model_name == 'vicuna-7b':
120
  if task == 'POS':
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
  strategy1_format = template_all.format(text)
122
  strategy2_format = prompt2_pos.format(text)
123
  strategy3_format = demon_pos
124
 
125
+ result1 = vicuna7b_pipeline(strategy1_format)[0]['generated_text']
126
+ result2 = vicuna7b_pipeline(strategy2_format)[0]['generated_text']
127
+ result3 = vicuna7b_pipeline(strategy3_format)[0]['generated_text']
128
+ return (result1, result2, result3)
129
+ elif task == 'Chunking':
130
+ strategy1_format = template_all.format(text)
131
+ strategy2_format = prompt2_chunk.format(text)
132
+ strategy3_format = demon_chunk
133
+
134
+ result1 = vicuna7b_pipeline(strategy1_format)[0]['generated_text']
135
+ result2 = vicuna7b_pipeline(strategy2_format)[0]['generated_text']
136
+ result3 = vicuna7b_pipeline(strategy3_format)[0]['generated_text']
137
+ return (result1, result2, result3)
138
+ elif task == 'Parsing':
139
+ strategy1_format = template_all.format(text)
140
+ strategy2_format = prompt2_parse.format(text)
141
+ strategy3_format = demon_parse
 
 
 
 
 
142
 
143
+ result1 = vicuna7b_pipeline(strategy1_format)[0]['generated_text']
144
+ result2 = vicuna7b_pipeline(strategy2_format)[0]['generated_text']
145
+ result3 = vicuna7b_pipeline(strategy3_format)[0]['generated_text']
146
+ return (result1, result2, result3)
147
 
148
  # Gradio interface
149
  iface = gr.Interface(