# Streamlit demo app: text summarization, question answering, and entity
# recognition served from locally cached HuggingFace Transformers checkpoints.
import os.path as osp
import streamlit as st
from spacy import displacy
import torch
from transformers import AutoTokenizer, AutoModelForQuestionAnswering, AutoModelForTokenClassification, \
AutoModelForSeq2SeqLM, pipeline
# Local checkpoint directories for each UI operation; all checkpoints live
# under the same pretrained-models root.
_PRETRAINED_ROOT = './data/pretrained_models'
model_paths = {
    'Text Summarization': f'{_PRETRAINED_ROOT}/roberta2roberta_L-24_bbc',
    'Question Answering': f'{_PRETRAINED_ROOT}/roberta-base-squad2',
    'Entity Recognition': f'{_PRETRAINED_ROOT}/xlm-roberta-large-finetuned-conll03-english',
}
@st.experimental_memo
def get_tokenizer(operation):
    """Load (and memoize via Streamlit) the tokenizer for *operation*.

    Args:
        operation: One of the keys of ``model_paths``.

    Returns:
        The tokenizer loaded from the operation's local checkpoint directory.
    """
    return AutoTokenizer.from_pretrained(model_paths[operation])
@st.experimental_memo
def get_model(operation):
    """Load (and memoize via Streamlit) the model for *operation*.

    Args:
        operation: One of the keys of ``model_paths``.

    Returns:
        The transformers model loaded from the operation's checkpoint.
    """
    # Unknown operations raise KeyError here, same as the original lookup.
    path = model_paths[operation]
    # Dispatch table replaces the if/elif chain: each operation maps to the
    # transformers auto-class that can load its checkpoint.
    auto_classes = {
        'Text Summarization': AutoModelForSeq2SeqLM,
        'Question Answering': AutoModelForQuestionAnswering,
        'Entity Recognition': AutoModelForTokenClassification,
    }
    return auto_classes[operation].from_pretrained(path)
def get_predictions(input_text, operations, input_question=None):
    """Run every requested NLP operation over *input_text*.

    Args:
        input_text: Raw text to analyze.
        operations: Iterable of operation names (keys of ``model_paths``).
        input_question: Question string; only used by 'Question Answering'.

    Returns:
        A list of ``(label, result)`` tuples, one per requested operation,
        in the order the operations were given.
    """
    results = []
    for op in operations:
        tok = get_tokenizer(op)
        mdl = get_model(op)
        if op == 'Text Summarization':
            results.append(('Predicted Summary', get_summary(input_text, tok, mdl)))
        if op == 'Question Answering':
            results.append(('Predicted Answer', get_answer(input_text, tok, mdl, input_question)))
        if op == 'Entity Recognition':
            results.append(('Predicted Entities', get_entities(input_text, tok, mdl)))
    return results
@st.experimental_memo
def get_summary(input_text, _tokenizer, _model):
    """Generate an abstractive summary of *input_text* with beam search.

    The leading underscores on ``_tokenizer``/``_model`` tell Streamlit's
    memoizer to skip hashing those (unhashable) arguments.
    """
    encoded = _tokenizer(input_text, max_length=512, return_tensors="pt", truncation=True)
    generated = _model.generate(
        encoded["input_ids"], num_beams=3, max_length=50, early_stopping=True
    )[0]
    return _tokenizer.decode(
        generated, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
@st.experimental_memo
def get_answer(input_text, _tokenizer, _model, input_question):
    """Extract the answer span for *input_question* from *input_text*.

    Uses the extractive QA convention: the highest-scoring start and end
    logits delimit the token span that is decoded back to a string.
    """
    encoded = _tokenizer(
        input_question, input_text, max_length=512, return_tensors="pt", truncation=True
    )
    with torch.inference_mode():  # no gradients needed for inference
        outputs = _model(**encoded)
    start = outputs.start_logits.argmax()
    end = outputs.end_logits.argmax()
    span_tokens = encoded.input_ids[0, start:end + 1]
    return _tokenizer.decode(span_tokens)
@st.experimental_memo
def get_entities(input_text, _tokenizer, _model):
    """Run NER over *input_text* and render the entities as displaCy HTML.

    Returns:
        The HTML string produced by ``displacy.render`` in manual mode.
    """
    ner = pipeline("ner", model=_model, tokenizer=_tokenizer, aggregation_strategy='simple')
    spans = [
        {'start': ent['start'], 'end': ent['end'], 'label': ent['entity_group']}
        for ent in ner(input_text)
    ]
    return displacy.render({'text': input_text, 'ents': spans}, style='ent', manual=True)
if __name__ == '__main__':
    # Smoke test: run the NER pipeline directly (outside Streamlit's caching
    # layer) and print the rendered displaCy markup. Stray debug
    # print("hello") removed.
    input_text = '''Lord of the Rings book was written by JRR Tolkein. It is one of the most popular fictional book ever written.
A very popular movie series is also based on the same book. Now, people are starting a television series based on the book as well.'''
    path = model_paths['Entity Recognition']
    tokenizer = AutoTokenizer.from_pretrained(path)
    model = AutoModelForTokenClassification.from_pretrained(path)
    classifier = pipeline("ner", model=model, tokenizer=tokenizer, aggregation_strategy='simple')
    # Convert each aggregated entity into the span dict displaCy expects.
    entities = [
        {'start': ent['start'], 'end': ent['end'], 'label': ent['entity_group']}
        for ent in classifier(input_text)
    ]
    display_output = displacy.render({'text': input_text, 'ents': entities}, style='ent', manual=True)
    print(display_output)