# Hugging Face Spaces page residue (kept as comments so the file parses):
# Spaces:
# Sleeping
# Visual QA demo built with Gradio.
import torch
from torch import nn

import gradio as gr
from transformers import (
    pipeline,
    AutoTokenizer,
    AutoModelForQuestionAnswering,
    BigBirdForQuestionAnswering,
    BigBirdConfig,
    PreTrainedModel,
    BigBirdTokenizer,
)
from transformers.models.big_bird.modeling_big_bird import BigBirdOutput, BigBirdIntermediate
class BigBirdNullHead(nn.Module):
    """Binary head predicting whether a question is answerable.

    Consumes an encoder representation and emits 2-way logits
    (answerable vs. "null"/no-answer).
    """

    def __init__(self, config):
        super().__init__()
        # Dropout -> intermediate FFN -> output projection (residual on the
        # raw encoder output) -> linear 2-class projection.
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.intermediate = BigBirdIntermediate(config)
        self.output = BigBirdOutput(config)
        self.qa_outputs = nn.Linear(config.hidden_size, 2)

    def forward(self, encoder_output):
        x = self.dropout(encoder_output)
        x = self.intermediate(x)
        # BigBirdOutput adds and normalizes against the original encoder output.
        x = self.output(x, encoder_output)
        return self.qa_outputs(x)
# Directory holding the fine-tuned checkpoint (pytorch_model.bin + tokenizer files).
model_path = './checkpoint-epoch-best'
class BigBirdForQuestionAnsweringWithNull(PreTrainedModel):
    """BigBird extractive-QA model extended with a "no answer" head.

    Wraps a pretrained ``BigBirdForQuestionAnswering`` backbone and adds a
    binary classifier on the pooled output that predicts whether the
    question is unanswerable (SQuAD-2.0-style null prediction).
    """

    def __init__(self, config, model_id):
        super().__init__(config)
        # Pooling layer is required: the null head reads the pooled output.
        self.bertqa = BigBirdForQuestionAnswering.from_pretrained(
            model_id, config=self.config, add_pooling_layer=True)
        self.null_classifier = BigBirdNullHead(self.bertqa.config)
        # Extra projection kept so fine-tuned checkpoints load without
        # missing-key errors (presumably used for contrastive training — TODO confirm).
        self.contrastive_mlp = nn.Sequential(
            nn.Linear(self.bertqa.config.hidden_size, self.bertqa.config.hidden_size),
        )

    def forward(self, **kwargs):
        if not self.training:
            # Inference: raw span logits plus answerable/unanswerable logits.
            outputs = self.bertqa(**kwargs)
            null_logits = self.null_classifier(outputs.pooler_output)
            return (outputs.start_logits, outputs.end_logits, null_logits)

        # Training: strip the null labels before forwarding the batch to the
        # backbone, then add the null-classification loss to the span loss.
        null_labels = kwargs.pop('is_impossible')
        outputs = self.bertqa(**kwargs)
        null_logits = self.null_classifier(outputs.pooler_output)
        null_loss = nn.CrossEntropyLoss()(null_logits, null_labels)
        outputs.loss = outputs.loss + null_loss
        return outputs.to_tuple()
# Build the model around the public NQ-finetuned BigBird checkpoint, then
# overwrite its weights with our own fine-tuned checkpoint.
model_id = 'vasudevgupta/bigbird-roberta-natural-questions'
config = BigBirdConfig.from_pretrained(model_id)
model = BigBirdForQuestionAnsweringWithNull(config, model_id)
# model.to('cuda')
# map_location='cpu': the model stays on CPU (the .to('cuda') above is
# commented out), and mapping to 'cuda' would crash on GPU-less machines.
model.load_state_dict(torch.load(model_path + '/pytorch_model.bin', map_location='cpu'))
model.eval()  # inference mode: disables dropout, routes forward() to the eval branch
tokenizer = BigBirdTokenizer.from_pretrained(model_path)
def main(question, context):
    """Answer `question` from `context`; return the answer span or "No Answer".

    The pair is joined with a [SEP] marker, tokenized (truncated to BigBird's
    4096-token window) and run through the model. If the null head predicts
    "unanswerable" we short-circuit; otherwise the highest-scoring start/end
    token positions are decoded back into a string.
    """
    # Encode the input.
    text = question + " [SEP] " + context
    inputs = tokenizer(text, max_length=4096, truncation=True, return_tensors="pt")
    # inputs.to('cuda')
    # Predict the answer. no_grad avoids building the autograd graph for a
    # pure-inference call (saves memory and time).
    with torch.no_grad():
        outputs = model(**inputs)
    start_scores, end_scores, null_scores = outputs
    # Decode the answer.
    is_impossible = null_scores.argmax().item()
    if is_impossible:
        return "No Answer"
    answer_start = torch.argmax(start_scores)
    # NOTE(review): end is argmaxed independently of start, so end < start is
    # possible and yields an empty answer — confirm whether that is intended.
    answer_end = torch.argmax(end_scores) + 1
    answer = tokenizer.convert_tokens_to_string(
        tokenizer.convert_ids_to_tokens(inputs['input_ids'][0][answer_start:answer_end]))
    return answer
# Gradio UI: two text inputs (question, context), one output box, and a
# button wired to the QA function.
with gr.Blocks() as demo:
    gr.Markdown("""# Question Answerer!""")
    with gr.Row():
        with gr.Column():
            question_box = gr.Textbox(
                label="Question",
                lines=1,
                value="Who does Cristiano Ronaldo play for?",
            )
            context_box = gr.Textbox(
                label="Context",
                lines=3,
                value="Cristiano Ronaldo is a player for Manchester United",
            )
            answer_box = gr.Textbox()
            ask_button = gr.Button("Ask Question!")
            ask_button.click(main, inputs=[question_box, context_box], outputs=answer_box)

if __name__ == "__main__":
    demo.launch(share=True)