Spaces:
Sleeping
Sleeping
Commit
·
0f3e715
1
Parent(s):
39fb3bc
Improve inputs
Browse files
app.py
CHANGED
@@ -13,13 +13,7 @@ def reduce_mean(value, mask, axis=None):
|
|
13 |
|
14 |
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
15 |
|
16 |
-
max_input_len = 256
|
17 |
-
max_output_len = 32
|
18 |
-
m = 10
|
19 |
-
top_p = 0.5
|
20 |
-
|
21 |
class InteractiveRainier:
|
22 |
-
|
23 |
def __init__(self):
|
24 |
self.tokenizer = transformers.AutoTokenizer.from_pretrained('allenai/unifiedqa-t5-large')
|
25 |
self.rainier_model = transformers.AutoModelForSeq2SeqLM.from_pretrained('liujch1998/rainier-large').to(device)
|
@@ -46,7 +40,7 @@ class InteractiveRainier:
|
|
46 |
choices.append(choice)
|
47 |
return choices
|
48 |
|
49 |
-
def run(self, question):
|
50 |
tokenized = self.tokenizer(question, return_tensors='pt', padding='max_length', truncation='longest_first', max_length=max_input_len).to(device) # (1, L)
|
51 |
knowledges_ids = self.rainier_model.generate(
|
52 |
input_ids=tokenized.input_ids,
|
@@ -107,8 +101,8 @@ class InteractiveRainier:
|
|
107 |
|
108 |
rainier = InteractiveRainier()
|
109 |
|
110 |
-
def predict(question,
|
111 |
-
result = rainier.run(
|
112 |
output = ''
|
113 |
output += f'QA model answer without knowledge: {result["knowless_pred"]}\n'
|
114 |
output += f'QA model answer with knowledge: {result["knowful_pred"]}\n'
|
@@ -120,13 +114,29 @@ def predict(question, choices):
|
|
120 |
output += f'Knowledge selected to make the prediction: {result["selected_knowledge"]}\n'
|
121 |
return output
|
122 |
|
123 |
-
input_question = gr.inputs.
|
124 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
125 |
output_text = gr.outputs.Textbox(label='Output')
|
126 |
|
127 |
gr.Interface(
|
128 |
fn=predict,
|
129 |
-
inputs=[input_question,
|
130 |
outputs=output_text,
|
131 |
title="Rainier",
|
132 |
).launch()
|
|
|
13 |
|
14 |
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
15 |
|
|
|
|
|
|
|
|
|
|
|
16 |
class InteractiveRainier:
|
|
|
17 |
def __init__(self):
|
18 |
self.tokenizer = transformers.AutoTokenizer.from_pretrained('allenai/unifiedqa-t5-large')
|
19 |
self.rainier_model = transformers.AutoModelForSeq2SeqLM.from_pretrained('liujch1998/rainier-large').to(device)
|
|
|
40 |
choices.append(choice)
|
41 |
return choices
|
42 |
|
43 |
+
def run(self, question, max_input_len, max_output_len, m, top_p):
|
44 |
tokenized = self.tokenizer(question, return_tensors='pt', padding='max_length', truncation='longest_first', max_length=max_input_len).to(device) # (1, L)
|
45 |
knowledges_ids = self.rainier_model.generate(
|
46 |
input_ids=tokenized.input_ids,
|
|
|
101 |
|
102 |
rainier = InteractiveRainier()
|
103 |
|
104 |
+
def predict(question, kg_model, qa_model, max_input_len, max_output_len, m, top_p):
|
105 |
+
result = rainier.run(question, max_input_len, max_output_len, m, top_p)
|
106 |
output = ''
|
107 |
output += f'QA model answer without knowledge: {result["knowless_pred"]}\n'
|
108 |
output += f'QA model answer with knowledge: {result["knowful_pred"]}\n'
|
|
|
114 |
output += f'Knowledge selected to make the prediction: {result["selected_knowledge"]}\n'
|
115 |
return output
|
116 |
|
117 |
+
input_question = gr.inputs.Dropdown(
|
118 |
+
choices=[
|
119 |
+
'If the mass of an object gets bigger what will happen to the amount of matter contained within it? \\n (A) gets bigger (B) gets smaller',
|
120 |
+
'What would vinyl be an odd thing to replace? \\n (A) pants (B) record albums (C) record store (D) cheese (E) wallpaper',
|
121 |
+
'Some pelycosaurs gave rise to reptile ancestral to \\n (A) lamphreys (B) angiosperm (C) mammals (D) paramecium (E) animals (F) protozoa (G) arachnids (H) backbones',
|
122 |
+
'Sydney rubbed Addison’s head because she had a horrible headache. What will happen to Sydney? \\n (A) drift to sleep (B) receive thanks (C) be reprimanded',
|
123 |
+
'Adam always spent all of the free time watching Tv unlike Hunter who volunteered, due to _ being lazy. \\n (A) Adam (B) Hunter',
|
124 |
+
'Causes bad breath and frightens blood-suckers \\n (A) tuna (B) iron (C) trash (D) garlic (E) pubs',
|
125 |
+
],
|
126 |
+
label='Question:',
|
127 |
+
info='A multiple-choice commonsense question. Please follow the UnifiedQA input format: "{question} \\n (A) ... (B) ... (C) ..."',
|
128 |
+
)
|
129 |
+
input_kg_model = gr.inputs.Textbox(label='Knowledge generation model:', value='liujch1998/rainier-large', interactive=False)
|
130 |
+
input_qa_model = gr.inputs.Textbox(label='QA model:', value='allenai/unifiedqa-t5-large', interactive=False)
|
131 |
+
input_max_input_len = gr.inputs.Number(label='Max question length:', value=256, precision=0)
|
132 |
+
input_max_output_len = gr.inputs.Number(label='Max knowledge length:', value=32, precision=0)
|
133 |
+
input_m = gr.inputs.Slider(label='Number of generated knowledges:', value=10, mininum=1, maximum=20, step=1)
|
134 |
+
input_top_p = gr.inputs.Slider(label='Top_p for knowledge generation:', value=0.5, mininum=0.0, maximum=1.0, step=0.05)
|
135 |
output_text = gr.outputs.Textbox(label='Output')
|
136 |
|
137 |
gr.Interface(
|
138 |
fn=predict,
|
139 |
+
inputs=[input_question, input_kg_model, input_qa_model, input_max_input_len, input_max_output_len, input_m, input_top_p],
|
140 |
outputs=output_text,
|
141 |
title="Rainier",
|
142 |
).launch()
|