"""Gradio demo: extractive question answering with a BERT model.

Also provides SQuAD-style helpers (``normalize_text``, ``exact_match``,
``compute_f1``) for scoring a predicted answer against a reference answer.
"""

import collections
import re
import string

import gradio as gr
import torch
from transformers import BertForQuestionAnswering, BertTokenizer

# NOTE(review): "bert-base-uncased" is a plain pretrained checkpoint; its
# question-answering head is presumably not fine-tuned, so predicted spans are
# likely arbitrary until a SQuAD-fine-tuned checkpoint is used (e.g.
# "bert-large-uncased-whole-word-masking-finetuned-squad") — confirm intent.
MODEL_NAME = "bert-base-uncased"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)
model = BertForQuestionAnswering.from_pretrained(MODEL_NAME).to(device)
model.eval()  # inference only: disable dropout so predictions are deterministic

# Precomputed once instead of on every normalize_text() call.
_ARTICLES_RE = re.compile(r"\b(a|an|the)\b", re.UNICODE)
_PUNCTUATION = frozenset(string.punctuation)


def get_prediction(context, question):
    """Extract the answer to *question* from *context* with the BERT QA model.

    Args:
        context: passage the answer is taken from.
        question: question to answer.

    Returns:
        The decoded answer span as a string, with special tokens removed.
    """
    # Encode the pair as [CLS] question [SEP] context [SEP]. If it exceeds the
    # model's maximum length, trim the context rather than the question.
    inputs = tokenizer(
        question,
        context,
        return_tensors="pt",
        truncation="only_second",
    ).to(device)
    with torch.no_grad():  # no gradients needed for inference
        outputs = model(**inputs)
    # outputs[0] / outputs[1] are the start / end logits; [0] drops the batch dim.
    start_logits, end_logits = outputs[0][0], outputs[1][0]
    answer_start = int(torch.argmax(start_logits))
    # Only consider end positions at or after the start so the span is never inverted.
    answer_end = answer_start + int(torch.argmax(end_logits[answer_start:])) + 1
    answer_ids = inputs["input_ids"][0][answer_start:answer_end].tolist()
    return tokenizer.decode(answer_ids, skip_special_tokens=True)


def normalize_text(s):
    """Normalise *s* for answer comparison (SQuAD evaluation convention).

    Lower-cases, removes ASCII punctuation, removes the articles "a", "an" and
    "the", and collapses whitespace runs to single spaces -- in that order.
    """
    text = s.lower()
    text = "".join(ch for ch in text if ch not in _PUNCTUATION)
    text = _ARTICLES_RE.sub(" ", text)
    return " ".join(text.split())


def exact_match(prediction, truth):
    """Return True when *prediction* and *truth* are equal after normalisation."""
    return normalize_text(prediction) == normalize_text(truth)


def compute_f1(prediction, truth):
    """Return the token-level F1 between *prediction* and *truth*, rounded to 2 dp.

    Token overlap is counted as a multiset (each shared token counts as many
    times as it appears in both), matching the SQuAD metric. If either side is
    empty after normalisation, returns 1 when both are empty and 0 otherwise.
    """
    pred_tokens = normalize_text(prediction).split()
    truth_tokens = normalize_text(truth).split()

    # No-answer case: F1 is 1 only if both sides agree there is no answer.
    if not pred_tokens or not truth_tokens:
        return int(pred_tokens == truth_tokens)

    overlap = collections.Counter(pred_tokens) & collections.Counter(truth_tokens)
    num_common = sum(overlap.values())
    if num_common == 0:
        return 0

    prec = num_common / len(pred_tokens)
    rec = num_common / len(truth_tokens)
    return round(2 * (prec * rec) / (prec + rec), 2)


def question_answer(context, question):
    """Answer *question* using *context*; this is the function the UI calls."""
    return get_prediction(context, question)


def greet(texts):
    """Answer a batch of questions.

    Args:
        texts: iterable of ``(context, question)`` pairs.

    Returns:
        List of predicted answer strings, one per pair, in input order.
    """
    return [question_answer(context, question) for context, question in texts]


# Two text inputs, passed positionally as (context, question) to question_answer.
iface = gr.Interface(fn=question_answer, inputs=["text", "text"], outputs="text")

if __name__ == "__main__":
    iface.launch()