flyboytarantino14 commited on
Commit
bf035c8
·
1 Parent(s): 4dbb7b0

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +43 -0
  2. requirements.txt +3 -0
app.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gradio as gr
3
+
4
+ from transformers import T5ForConditionalGeneration, T5Tokenizer
5
+ #question_model = T5ForConditionalGeneration.from_pretrained('ramsrigouthamg/t5_squad_v1')
6
+ #question_tokenizer = T5Tokenizer.from_pretrained('t5-base')
7
+ question_model = T5ForConditionalGeneration.from_pretrained('t5-small')
8
+ question_tokenizer = T5Tokenizer.from_pretrained('t5-small')
9
+
10
+ def get_question(context, answer):
11
+ text = "context: {} answer: {}".format(context, answer)
12
+ #max_len = 512
13
+ #encoding = question_tokenizer.encode_plus(text, max_length=max_len, padding='max_length', truncation=True, return_tensors="pt")
14
+ encoding = question_tokenizer.encode_plus(text, return_tensors="pt")
15
+ input_ids, attention_mask = encoding["input_ids"], encoding["attention_mask"]
16
+ outs = question_model.generate(input_ids=input_ids,
17
+ attention_mask=attention_mask,
18
+ early_stopping=True,
19
+ num_beams=3, # Use fewer beams to generate fewer but higher-quality questions
20
+ num_return_sequences=3,
21
+ no_repeat_ngram_size=3, # Allow some repetition to avoid generating nonsensical questions
22
+ max_length=256) # Use a shorter max length to focus on generating more relevant questions
23
+ dec = [question_tokenizer.decode(ids) for ids in outs]
24
+ questions = ""
25
+ for i, question in enumerate(dec):
26
+ question = question.replace("question:", "").replace("<pad>", "").replace("</s>", "")
27
+ question = question.strip()
28
+ questions += question
29
+ if i != len(dec)-1:
30
+ questions += "§"
31
+ return questions
32
+
33
+ input_context = gr.Textbox()
34
+ input_answer = gr.Textbox()
35
+ output_question = gr.Textbox()
36
+
37
+ interface = gr.Interface(
38
+ fn=get_question,
39
+ inputs=[input_context, input_answer],
40
+ outputs=output_question
41
+ )
42
+
43
+ interface.launch()
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ transformers
2
+ sentencepiece
3
+ torch