liujch1998 committed on
Commit
0f3e715
·
1 Parent(s): 39fb3bc

Improve inputs

Browse files
Files changed (1) hide show
  1. app.py +22 -12
app.py CHANGED
@@ -13,13 +13,7 @@ def reduce_mean(value, mask, axis=None):
13
 
14
  device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
15
 
16
- max_input_len = 256
17
- max_output_len = 32
18
- m = 10
19
- top_p = 0.5
20
-
21
  class InteractiveRainier:
22
-
23
  def __init__(self):
24
  self.tokenizer = transformers.AutoTokenizer.from_pretrained('allenai/unifiedqa-t5-large')
25
  self.rainier_model = transformers.AutoModelForSeq2SeqLM.from_pretrained('liujch1998/rainier-large').to(device)
@@ -46,7 +40,7 @@ class InteractiveRainier:
46
  choices.append(choice)
47
  return choices
48
 
49
- def run(self, question):
50
  tokenized = self.tokenizer(question, return_tensors='pt', padding='max_length', truncation='longest_first', max_length=max_input_len).to(device) # (1, L)
51
  knowledges_ids = self.rainier_model.generate(
52
  input_ids=tokenized.input_ids,
@@ -107,8 +101,8 @@ class InteractiveRainier:
107
 
108
  rainier = InteractiveRainier()
109
 
110
- def predict(question, choices):
111
- result = rainier.run(f'{question} \\n {choices}')
112
  output = ''
113
  output += f'QA model answer without knowledge: {result["knowless_pred"]}\n'
114
  output += f'QA model answer with knowledge: {result["knowful_pred"]}\n'
@@ -120,13 +114,29 @@ def predict(question, choices):
120
  output += f'Knowledge selected to make the prediction: {result["selected_knowledge"]}\n'
121
  return output
122
 
123
- input_question = gr.inputs.Textbox(label='Question:')
124
- input_choices = gr.inputs.Textbox(label='Choices:')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
  output_text = gr.outputs.Textbox(label='Output')
126
 
127
  gr.Interface(
128
  fn=predict,
129
- inputs=[input_question, input_choices],
130
  outputs=output_text,
131
  title="Rainier",
132
  ).launch()
 
13
 
14
  device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
15
 
 
 
 
 
 
16
  class InteractiveRainier:
 
17
  def __init__(self):
18
  self.tokenizer = transformers.AutoTokenizer.from_pretrained('allenai/unifiedqa-t5-large')
19
  self.rainier_model = transformers.AutoModelForSeq2SeqLM.from_pretrained('liujch1998/rainier-large').to(device)
 
40
  choices.append(choice)
41
  return choices
42
 
43
+ def run(self, question, max_input_len, max_output_len, m, top_p):
44
  tokenized = self.tokenizer(question, return_tensors='pt', padding='max_length', truncation='longest_first', max_length=max_input_len).to(device) # (1, L)
45
  knowledges_ids = self.rainier_model.generate(
46
  input_ids=tokenized.input_ids,
 
101
 
102
  rainier = InteractiveRainier()
103
 
104
+ def predict(question, kg_model, qa_model, max_input_len, max_output_len, m, top_p):
105
+ result = rainier.run(question, max_input_len, max_output_len, m, top_p)
106
  output = ''
107
  output += f'QA model answer without knowledge: {result["knowless_pred"]}\n'
108
  output += f'QA model answer with knowledge: {result["knowful_pred"]}\n'
 
114
  output += f'Knowledge selected to make the prediction: {result["selected_knowledge"]}\n'
115
  return output
116
 
117
# --- Gradio input widgets for the demo UI ---
# Questions follow the UnifiedQA input format: "{question} \n (A) ... (B) ..."
# NOTE(review): `gr.inputs.*` is the legacy Gradio input API — confirm the
# installed Gradio version still supports it (and the `info=` / `interactive=`
# kwargs) before deploying.
input_question = gr.inputs.Dropdown(
    choices=[
        'If the mass of an object gets bigger what will happen to the amount of matter contained within it? \\n (A) gets bigger (B) gets smaller',
        'What would vinyl be an odd thing to replace? \\n (A) pants (B) record albums (C) record store (D) cheese (E) wallpaper',
        'Some pelycosaurs gave rise to reptile ancestral to \\n (A) lamphreys (B) angiosperm (C) mammals (D) paramecium (E) animals (F) protozoa (G) arachnids (H) backbones',
        'Sydney rubbed Addison’s head because she had a horrible headache. What will happen to Sydney? \\n (A) drift to sleep (B) receive thanks (C) be reprimanded',
        'Adam always spent all of the free time watching Tv unlike Hunter who volunteered, due to _ being lazy. \\n (A) Adam (B) Hunter',
        'Causes bad breath and frightens blood-suckers \\n (A) tuna (B) iron (C) trash (D) garlic (E) pubs',
    ],
    label='Question:',
    info='A multiple-choice commonsense question. Please follow the UnifiedQA input format: "{question} \\n (A) ... (B) ... (C) ..."',
)
# Model names are displayed but fixed (interactive=False).
input_kg_model = gr.inputs.Textbox(label='Knowledge generation model:', value='liujch1998/rainier-large', interactive=False)
input_qa_model = gr.inputs.Textbox(label='QA model:', value='allenai/unifiedqa-t5-large', interactive=False)
# Tokenizer truncation / generation-length limits (whole-token counts).
input_max_input_len = gr.inputs.Number(label='Max question length:', value=256, precision=0)
input_max_output_len = gr.inputs.Number(label='Max knowledge length:', value=32, precision=0)
# BUG FIX: both sliders spelled the keyword 'mininum' — Gradio rejects the
# unknown kwarg, crashing at startup. Corrected to 'minimum'.
input_m = gr.inputs.Slider(label='Number of generated knowledges:', value=10, minimum=1, maximum=20, step=1)
input_top_p = gr.inputs.Slider(label='Top_p for knowledge generation:', value=0.5, minimum=0.0, maximum=1.0, step=0.05)
135
# Single text box that shows the model's prediction summary.
output_text = gr.outputs.Textbox(label='Output')

# Wire the widgets to `predict` and start the demo server.
demo = gr.Interface(
    fn=predict,
    inputs=[input_question, input_kg_model, input_qa_model, input_max_input_len, input_max_output_len, input_m, input_top_p],
    outputs=output_text,
    title="Rainier",
)
demo.launch()