Spaces:
Build error
Build error
File size: 5,454 Bytes
546a9ba |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 |
from model.base_model import SummModel
from model.single_doc import TextRankModel
from typing import List, Union
from nltk import sent_tokenize, word_tokenize
from nltk.corpus import stopwords
from nltk.stem import PorterStemmer
class QueryBasedSummModel(SummModel):
    """Base class for retrieve-then-summarize, query-based summarization.

    For every (instance, query) pair, a subset of the instance's sentences
    (or dialogue turns) is selected by ``_retrieve`` and the retained text is
    handed to a backend summarization model. ``_retrieve`` is not implemented
    here; subclasses must provide it.
    """

    is_query_based = True

    def __init__(
        self,
        trained_domain: str = None,
        max_input_length: int = None,
        max_output_length: int = None,
        model_backend: SummModel = TextRankModel,
        retrieval_ratio: float = 0.5,
        preprocess: bool = True,
        **kwargs,
    ):
        """Set up the retriever settings and instantiate the backend model.

        Args:
            trained_domain: Forwarded unchanged to ``SummModel.__init__``.
            max_input_length: Forwarded unchanged to ``SummModel.__init__``.
            max_output_length: Forwarded unchanged to ``SummModel.__init__``.
            model_backend: Callable (normally a ``SummModel`` subclass, not an
                instance) that is invoked as ``model_backend(**kwargs)`` to
                build the summarizer used after retrieval.
            retrieval_ratio: Fraction of sentences/turns of each instance that
                is kept by the retrieval step (at least one is always kept).
            preprocess: When true, instances and queries are run through
                ``Preprocessor`` before retrieval.
            **kwargs: Passed verbatim to ``model_backend``.
        """
        super().__init__(
            trained_domain=trained_domain,
            max_input_length=max_input_length,
            max_output_length=max_output_length,
        )
        self.model = model_backend(**kwargs)
        self.retrieval_ratio = retrieval_ratio
        self.preprocess = preprocess

    def _retrieve(
        self, instance: List[str], query: List[str], n_best: int
    ) -> List[str]:
        """Select the ``n_best`` entries of ``instance`` for ``query``.

        Args:
            instance: Sentences of one document or turns of one dialogue.
            query: The query wrapped in a single-element list.
            n_best: Number of entries ``summarize`` asks to keep.

        Returns:
            The selected entries as a list of strings.

        Raises:
            NotImplementedError: Always; subclasses must override this hook.
        """
        raise NotImplementedError()

    def summarize(
        self,
        corpus: Union[List[str], List[List[str]]],
        queries: List[str] = None,
    ) -> List[str]:
        """Summarize each instance of ``corpus`` with respect to its query.

        Args:
            corpus: One entry per instance. A ``str`` entry is treated as a
                document and split into sentences; any other entry is treated
                as a dialogue that is already a list of turns.
            queries: One query string per instance, aligned with ``corpus``.

        Returns:
            Whatever the backend model's ``summarize`` returns for the
            retrieved text (one item per instance).

        Raises:
            TypeError: If ``queries`` is missing or is not a list of strings.
            ValueError: If ``corpus`` and ``queries`` differ in length.
        """
        self.assert_summ_input_type(corpus, queries)

        # zip() would silently drop the unmatched tail and return fewer
        # summaries than instances, so reject misaligned input up front.
        if len(corpus) != len(queries):
            raise ValueError(
                "Query-based summarization requires one query per instance, "
                f"got {len(corpus)} instances and {len(queries)} queries."
            )

        # Building a Preprocessor loads the stopword list and a stemmer; do it
        # once for the whole batch instead of once per instance.
        preprocessor = Preprocessor() if self.preprocess else None

        retrieval_output = []  # List[str] or List[List[str]]
        for instance, query in zip(corpus, queries):
            if isinstance(instance, str):
                is_dialogue = False
                instance = sent_tokenize(instance)
            else:
                is_dialogue = True
            query = [query]

            # instance & query now are List[str] for sure
            if preprocessor is not None:
                instance = preprocessor.preprocess(instance)
                query = preprocessor.preprocess(query)

            # Always retrieve at least one sentence/turn.
            n_best = max(int(len(instance) * self.retrieval_ratio), 1)
            top_n_sent = self._retrieve(instance, query, n_best)

            if not is_dialogue:
                # Documents go to the backend as a single string; dialogues
                # stay as a list of turns.
                top_n_sent = " ".join(top_n_sent)  # str
            retrieval_output.append(top_n_sent)

        summaries = self.model.summarize(
            retrieval_output
        )  # List[str] or List[List[str]]
        return summaries

    def generate_specific_description(self):
        """Describe this retriever + backend pipeline in one sentence.

        The pipeline is reported as neural only when both stages are neural,
        and as extractive when either stage is extractive.

        Returns:
            The description string.
        """
        is_neural = self.model.is_neural and self.is_neural
        is_extractive = self.model.is_extractive or self.is_extractive
        model_name = "Pipeline with retriever: {}, summarizer: {}".format(
            self.model_name, self.model.model_name
        )

        extractive_abstractive = "extractive" if is_extractive else "abstractive"
        neural = "neural" if is_neural else "non-neural"

        basic_description = (
            f"{model_name} is a "
            f"{'query-based' if self.is_query_based else ''} "
            f"{extractive_abstractive}, {neural} model for summarization."
        )
        return basic_description

    @classmethod
    def assert_summ_input_type(cls, corpus, query):
        """Validate that ``query`` is a list of strings.

        Only ``query`` is inspected here; ``corpus`` is accepted as-is.

        Raises:
            TypeError: If ``query`` is ``None``, is not a list, or contains a
                non-string element.
        """
        if query is None:
            raise TypeError(
                "Query-based summarization models summarize instances of query-text pairs, however, query is missing."
            )

        if not isinstance(query, list):
            raise TypeError(
                "Query-based single-document summarization requires query of `List[str]`."
            )

        if not all(isinstance(q, str) for q in query):
            raise TypeError(
                "Query-based single-document summarization requires query of `List[str]`."
            )

    @classmethod
    def generate_basic_description(cls) -> str:
        """Return a short, human-readable description of the model family."""
        # Adjacent literals are concatenated verbatim, so each fragment must
        # carry its own separating space.
        basic_description = (
            "QueryBasedSummModel performs query-based summarization. Given a query-text pair, "
            "the model will first extract the most relevant sentences in articles or turns in "
            "dialogues, then use the single document summarization model to generate the summary"
        )
        return basic_description

    @classmethod
    def show_capability(cls):
        """Print the basic description followed by strengths and weaknesses."""
        basic_description = cls.generate_basic_description()
        more_details = (
            "A query-based summarization model."
            " Allows for custom model backend selection at initialization."
            " Retrieve relevant turns and then summarize the retrieved turns\n"
            "Strengths: \n - Allows for control of backend model.\n"
            "Weaknesses: \n - Heavily depends on the performance of both retriever and summarizer.\n"
        )
        print(f"{basic_description}\n{'#' * 20}\n{more_details}")
class Preprocessor:
    """Text normalizer applied to sentences and queries before retrieval.

    Depending on the flags given at construction, ``preprocess`` lower-cases
    the text, removes English stopwords and stems the remaining tokens.
    """

    def __init__(self, remove_stopwords=True, lower_case=True, stem=False):
        """Load the English stopword list and a Porter stemmer.

        Args:
            remove_stopwords: Drop tokens found in the stopword list.
            lower_case: Lower-case every sentence before tokenizing.
            stem: Replace every token by its Porter stem.
        """
        self.sw = stopwords.words("english")
        self.stemmer = PorterStemmer()
        self.remove_stopwords = remove_stopwords
        self.lower_case = lower_case
        self.stem = stem

    def preprocess(self, corpus: List[str]) -> List[str]:
        """Normalize every sentence of ``corpus``.

        Args:
            corpus: Sentences (or dialogue turns / queries) to normalize.

        Returns:
            One string per input sentence, made of the surviving tokens
            joined by single spaces. The list has the same length and order
            as ``corpus``.
        """
        if self.lower_case:
            corpus = [sent.lower() for sent in corpus]
        tokenized_corpus = [word_tokenize(sent) for sent in corpus]

        if self.remove_stopwords:
            # Hash-based lookup instead of scanning the stopword list for
            # every token. Built per call so edits to ``self.sw`` are honored.
            stop_set = frozenset(self.sw)
            # The stopword list is expected to be lowercase, so also test the
            # lowercase form: otherwise "The"/"And" survive the filter when
            # ``lower_case`` is disabled.
            tokenized_corpus = [
                [
                    word
                    for word in sent
                    if word not in stop_set and word.lower() not in stop_set
                ]
                for sent in tokenized_corpus
            ]

        if self.stem:
            stem_word = self.stemmer.stem
            tokenized_corpus = [
                [stem_word(word) for word in sent] for sent in tokenized_corpus
            ]

        return [" ".join(sent) for sent in tokenized_corpus]
|