research14 committed on
Commit
5cb31ec
·
1 Parent(s): 6525dcf
Files changed (1) hide show
  1. app.py +14 -14
app.py CHANGED
@@ -50,8 +50,8 @@ for i, j in zip(ents, ents_prompt):
50
  print(i, j)
51
 
52
  model_mapping = {
53
- 'gpt3.5': 'gpt2',
54
- #'vicuna-7b': 'lmsys/vicuna-7b-v1.3',
55
  #'vicuna-13b': 'lmsys/vicuna-13b-v1.3',
56
  #'vicuna-33b': 'lmsys/vicuna-33b-v1.3',
57
  #'fastchat-t5': 'lmsys/fastchat-t5-3b-v1.0',
@@ -94,8 +94,8 @@ with open('demonstration_3_42_parse.txt', 'r') as f:
94
  theme = gr.themes.Soft()
95
 
96
 
97
- gpt_pipeline = pipeline(task="text-generation", model="gpt2")
98
- #vicuna7b_pipeline = pipeline(task="text2text-generation", model="lmsys/vicuna-7b-v1.3")
99
  #vicuna13b_pipeline = pipeline(task="text2text-generation", model="lmsys/vicuna-13b-v1.3")
100
  #vicuna33b_pipeline = pipeline(task="text2text-generation", model="lmsys/vicuna-33b-v1.3")
101
  #fastchatT5_pipeline = pipeline(task="text2text-generation", model="lmsys/fastchat-t5-3b-v1.0")
@@ -117,33 +117,33 @@ def process_text(model_name, task, text):
117
  for gid in tqdm(gid_list, desc='Query'):
118
  text = ptb[gid]['text']
119
 
120
- if model_name == 'gpt3.5':
121
  if task == 'POS':
122
  strategy1 = template_all.format(text)
123
  strategy2 = prompt2_pos.format(text)
124
  strategy3 = demon_pos
125
 
126
- response1 = gpt_pipeline(strategy1)[0]['generated_text']
127
- response2 = gpt_pipeline(strategy2)[0]['generated_text']
128
- response3 = gpt_pipeline(strategy3)[0]['generated_text']
129
  return (response1, response2, response3)
130
  elif task == 'Chunking':
131
  strategy1 = template_all.format(text)
132
  strategy2 = prompt2_chunk.format(text)
133
  strategy3 = demon_chunk
134
 
135
- response1 = gpt_pipeline(strategy1)[0]['generated_text']
136
- response2 = gpt_pipeline(strategy2)[0]['generated_text']
137
- response3 = gpt_pipeline(strategy3)[0]['generated_text']
138
  return (response1, response2, response3)
139
  elif task == 'Parsing':
140
  strategy1 = template_all.format(text)
141
  strategy2 = prompt2_parse.format(text)
142
  strategy3 = demon_parse
143
 
144
- response1 = gpt_pipeline(strategy1)[0]['generated_text']
145
- response2 = gpt_pipeline(strategy2)[0]['generated_text']
146
- response3 = gpt_pipeline(strategy3)[0]['generated_text']
147
  return (response1, response2, response3)
148
 
149
  # Gradio interface
 
50
  print(i, j)
51
 
52
  model_mapping = {
53
+ #'gpt3.5': 'gpt2',
54
+ 'vicuna-7b': 'lmsys/vicuna-7b-v1.3',
55
  #'vicuna-13b': 'lmsys/vicuna-13b-v1.3',
56
  #'vicuna-33b': 'lmsys/vicuna-33b-v1.3',
57
  #'fastchat-t5': 'lmsys/fastchat-t5-3b-v1.0',
 
94
  theme = gr.themes.Soft()
95
 
96
 
97
+ #gpt_pipeline = pipeline(task="text-generation", model="gpt2")
98
+ vicuna7b_pipeline = pipeline(task="text2text-generation", model="lmsys/vicuna-7b-v1.3")
99
  #vicuna13b_pipeline = pipeline(task="text2text-generation", model="lmsys/vicuna-13b-v1.3")
100
  #vicuna33b_pipeline = pipeline(task="text2text-generation", model="lmsys/vicuna-33b-v1.3")
101
  #fastchatT5_pipeline = pipeline(task="text2text-generation", model="lmsys/fastchat-t5-3b-v1.0")
 
117
  for gid in tqdm(gid_list, desc='Query'):
118
  text = ptb[gid]['text']
119
 
120
+ if model_name == 'vicuna-7b':
121
  if task == 'POS':
122
  strategy1 = template_all.format(text)
123
  strategy2 = prompt2_pos.format(text)
124
  strategy3 = demon_pos
125
 
126
+ response1 = vicuna7b_pipeline(strategy1)[0]['generated_text']
127
+ response2 = vicuna7b_pipeline(strategy2)[0]['generated_text']
128
+ response3 = vicuna7b_pipeline(strategy3)[0]['generated_text']
129
  return (response1, response2, response3)
130
  elif task == 'Chunking':
131
  strategy1 = template_all.format(text)
132
  strategy2 = prompt2_chunk.format(text)
133
  strategy3 = demon_chunk
134
 
135
+ response1 = vicuna7b_pipeline(strategy1)[0]['generated_text']
136
+ response2 = vicuna7b_pipeline(strategy2)[0]['generated_text']
137
+ response3 = vicuna7b_pipeline(strategy3)[0]['generated_text']
138
  return (response1, response2, response3)
139
  elif task == 'Parsing':
140
  strategy1 = template_all.format(text)
141
  strategy2 = prompt2_parse.format(text)
142
  strategy3 = demon_parse
143
 
144
+ response1 = vicuna7b_pipeline(strategy1)[0]['generated_text']
145
+ response2 = vicuna7b_pipeline(strategy2)[0]['generated_text']
146
+ response3 = vicuna7b_pipeline(strategy3)[0]['generated_text']
147
  return (response1, response2, response3)
148
 
149
  # Gradio interface