Spaces:
Runtime error
Runtime error
Commit
·
5cb31ec
1
Parent(s):
6525dcf
test
Browse files
app.py
CHANGED
@@ -50,8 +50,8 @@ for i, j in zip(ents, ents_prompt):
|
|
50 |
print(i, j)
|
51 |
|
52 |
model_mapping = {
|
53 |
-
'gpt3.5': 'gpt2',
|
54 |
-
|
55 |
#'vicuna-13b': 'lmsys/vicuna-13b-v1.3',
|
56 |
#'vicuna-33b': 'lmsys/vicuna-33b-v1.3',
|
57 |
#'fastchat-t5': 'lmsys/fastchat-t5-3b-v1.0',
|
@@ -94,8 +94,8 @@ with open('demonstration_3_42_parse.txt', 'r') as f:
|
|
94 |
theme = gr.themes.Soft()
|
95 |
|
96 |
|
97 |
-
gpt_pipeline = pipeline(task="text-generation", model="gpt2")
|
98 |
-
|
99 |
#vicuna13b_pipeline = pipeline(task="text2text-generation", model="lmsys/vicuna-13b-v1.3")
|
100 |
#vicuna33b_pipeline = pipeline(task="text2text-generation", model="lmsys/vicuna-33b-v1.3")
|
101 |
#fastchatT5_pipeline = pipeline(task="text2text-generation", model="lmsys/fastchat-t5-3b-v1.0")
|
@@ -117,33 +117,33 @@ def process_text(model_name, task, text):
|
|
117 |
for gid in tqdm(gid_list, desc='Query'):
|
118 |
text = ptb[gid]['text']
|
119 |
|
120 |
-
if model_name == '
|
121 |
if task == 'POS':
|
122 |
strategy1 = template_all.format(text)
|
123 |
strategy2 = prompt2_pos.format(text)
|
124 |
strategy3 = demon_pos
|
125 |
|
126 |
-
response1 =
|
127 |
-
response2 =
|
128 |
-
response3 =
|
129 |
return (response1, response2, response3)
|
130 |
elif task == 'Chunking':
|
131 |
strategy1 = template_all.format(text)
|
132 |
strategy2 = prompt2_chunk.format(text)
|
133 |
strategy3 = demon_chunk
|
134 |
|
135 |
-
response1 =
|
136 |
-
response2 =
|
137 |
-
response3 =
|
138 |
return (response1, response2, response3)
|
139 |
elif task == 'Parsing':
|
140 |
strategy1 = template_all.format(text)
|
141 |
strategy2 = prompt2_parse.format(text)
|
142 |
strategy3 = demon_parse
|
143 |
|
144 |
-
response1 =
|
145 |
-
response2 =
|
146 |
-
response3 =
|
147 |
return (response1, response2, response3)
|
148 |
|
149 |
# Gradio interface
|
|
|
50 |
print(i, j)
|
51 |
|
52 |
model_mapping = {
|
53 |
+
#'gpt3.5': 'gpt2',
|
54 |
+
'vicuna-7b': 'lmsys/vicuna-7b-v1.3',
|
55 |
#'vicuna-13b': 'lmsys/vicuna-13b-v1.3',
|
56 |
#'vicuna-33b': 'lmsys/vicuna-33b-v1.3',
|
57 |
#'fastchat-t5': 'lmsys/fastchat-t5-3b-v1.0',
|
|
|
94 |
theme = gr.themes.Soft()
|
95 |
|
96 |
|
97 |
+
#gpt_pipeline = pipeline(task="text-generation", model="gpt2")
|
98 |
+
vicuna7b_pipeline = pipeline(task="text2text-generation", model="lmsys/vicuna-7b-v1.3")
|
99 |
#vicuna13b_pipeline = pipeline(task="text2text-generation", model="lmsys/vicuna-13b-v1.3")
|
100 |
#vicuna33b_pipeline = pipeline(task="text2text-generation", model="lmsys/vicuna-33b-v1.3")
|
101 |
#fastchatT5_pipeline = pipeline(task="text2text-generation", model="lmsys/fastchat-t5-3b-v1.0")
|
|
|
117 |
for gid in tqdm(gid_list, desc='Query'):
|
118 |
text = ptb[gid]['text']
|
119 |
|
120 |
+
if model_name == 'vicuna-7b':
|
121 |
if task == 'POS':
|
122 |
strategy1 = template_all.format(text)
|
123 |
strategy2 = prompt2_pos.format(text)
|
124 |
strategy3 = demon_pos
|
125 |
|
126 |
+
response1 = vicuna7b_pipeline(strategy1)[0]['generated_text']
|
127 |
+
response2 = vicuna7b_pipeline(strategy2)[0]['generated_text']
|
128 |
+
response3 = vicuna7b_pipeline(strategy3)[0]['generated_text']
|
129 |
return (response1, response2, response3)
|
130 |
elif task == 'Chunking':
|
131 |
strategy1 = template_all.format(text)
|
132 |
strategy2 = prompt2_chunk.format(text)
|
133 |
strategy3 = demon_chunk
|
134 |
|
135 |
+
response1 = vicuna7b_pipeline(strategy1)[0]['generated_text']
|
136 |
+
response2 = vicuna7b_pipeline(strategy2)[0]['generated_text']
|
137 |
+
response3 = vicuna7b_pipeline(strategy3)[0]['generated_text']
|
138 |
return (response1, response2, response3)
|
139 |
elif task == 'Parsing':
|
140 |
strategy1 = template_all.format(text)
|
141 |
strategy2 = prompt2_parse.format(text)
|
142 |
strategy3 = demon_parse
|
143 |
|
144 |
+
response1 = vicuna7b_pipeline(strategy1)[0]['generated_text']
|
145 |
+
response2 = vicuna7b_pipeline(strategy2)[0]['generated_text']
|
146 |
+
response3 = vicuna7b_pipeline(strategy3)[0]['generated_text']
|
147 |
return (response1, response2, response3)
|
148 |
|
149 |
# Gradio interface
|