import nltk
import pandas as pd
from tqdm.auto import tqdm
from transformers import AutoTokenizer, pipeline

# Download the NLTK tokenizer data if not already present.
nltk.download("punkt_tab")
nltk.download("punkt")
class WBGDocTopic:
    """
    A class that suggests document topics using an ensemble of pre-trained text classification models.

    This class loads a set of text classification models from the Hugging Face model hub and
    provides a method that suggests topics for input documents based on the aggregated
    classification results from all of the models.

    Attributes:
    -----------
    classifiers : dict
        A dictionary mapping model names to their classification pipelines. It holds
        instances of Hugging Face's `pipeline` used for text classification.
    device : str
        The device on which the classification pipelines run (e.g., "cpu" or "cuda:0").

    Methods:
    --------
    __init__(classifiers: dict = None, device: str = None)
        Initializes the `WBGDocTopic` instance. If no classifiers are provided, a default
        set is loaded by calling `load_classifiers`.
    load_classifiers()
        Loads a predefined set of document topic classifiers into the `classifiers` dictionary,
        displaying progress with `tqdm`.
    suggest_topics(input_docs: str | list[str]) -> list
        Suggests topics for the given document or list of documents. It runs each document
        through all classifiers, averages their scores, and returns a list in which each entry
        contains the mean and standard deviation of the topic scores for one document.

    Parameters:
    -----------
    input_docs : str or list of str
        A single document or a list of documents for which to suggest topics.

    Returns:
    --------
    list
        A list with one entry per document, where each entry holds the suggested topics
        together with the mean and standard deviation of the topic classification scores.
    """
    def __init__(self, classifiers: dict = None, device: str = None):
        self.classifiers = classifiers or {}
        self.device = device

        # Load the default set of classifiers when none are supplied.
        if classifiers is None:
            self.load_classifiers()
    def load_classifiers(self):
        num_evals = 5
        num_train = 5

        # All fold models share the same tokenizer, so load it once.
        tokenizer = AutoTokenizer.from_pretrained("avsolatorio/doc-topic-model_eval-04_train-03")

        # Load one classifier per eval/train fold pair, skipping pairs where the folds coincide.
        for i in tqdm(range(num_evals)):
            for j in tqdm(range(num_train)):
                if i == j:
                    continue
                model_name = f"avsolatorio/doc-topic-model_eval-{i:02}_train-{j:02}"
                classifier = pipeline(
                    "text-classification", model=model_name, tokenizer=tokenizer, top_k=None, device=self.device
                )
                self.classifiers[model_name] = classifier
    def suggest_topics(self, input_docs: str | list[str]):
        if isinstance(input_docs, str):
            input_docs = [input_docs]

        # Collect the per-classifier score tables for each document.
        doc_outs = {i: [] for i in range(len(input_docs))}
        topics = []

        for classifier in self.classifiers.values():
            for doc_idx, doc in enumerate(classifier(input_docs)):
                doc_outs[doc_idx].append(pd.DataFrame.from_records(doc, index="label"))

        # Aggregate across classifiers: mean and standard deviation of each topic's score.
        for doc_idx, outs in doc_outs.items():
            all_scores = pd.concat(outs, axis=1)
            mean_probs = all_scores.mean(axis=1).sort_values(ascending=False)
            std_probs = all_scores.std(axis=1).loc[mean_probs.index]

            output = pd.DataFrame({"score_mean": mean_probs, "score_std": std_probs})
            output["doc_idx"] = doc_idx
            output.reset_index(inplace=True)

            topics.append(output.to_dict(orient="records"))

        return topics
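

if __name__ == "__main__":
    # Minimal usage sketch: the sample document text below is an illustrative assumption,
    # and instantiating with defaults downloads and loads the full ensemble of fold models.
    topic_suggester = WBGDocTopic(device="cpu")
    sample_doc = (
        "The project supports the expansion of renewable energy access and "
        "improved electricity reliability for rural communities."
    )
    suggestions = topic_suggester.suggest_topics(sample_doc)

    # Each entry in `suggestions` corresponds to one input document; its records are sorted
    # by descending mean score, so the first few are the strongest topic suggestions.
    for record in suggestions[0][:5]:
        print(record["label"], record["score_mean"], record["score_std"])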