|
---
license: mit
---
|
|
|
# The WBG Doc Topic Container
|
|
|
|
|
```Python
|
from transformers import pipeline |
|
from tqdm.auto import tqdm |
|
import pandas as pd |
|
from transformers import AutoTokenizer |
|
|
|
|
|
class WBGDocTopic: |
|
""" |
|
A class to handle document topic suggestion using multiple pre-trained text classification models. |
|
|
|
This class loads a set of text classification models from Hugging Face's model hub and |
|
provides a method to suggest topics for input documents based on the aggregated classification |
|
results from all the models. |
|
|
|
Attributes: |
|
----------- |
|
classifiers : dict |
|
A dictionary mapping model names to corresponding classification pipelines. It holds |
|
instances of Hugging Face's `pipeline` used for text classification. |
|
|
|
Methods: |
|
-------- |
|
__init__(classifiers: dict = None) |
|
Initializes the `WBGDocTopic` instance. If no classifiers are provided, it loads a default |
|
set of classifiers by calling `load_classifiers`. |
|
|
|
load_classifiers() |
|
Loads a predefined set of document topic classifiers into the `classifiers` dictionary. |
|
It uses `tqdm` to display progress as the classifiers are loaded. |
|
|
|
suggest_topics(input_docs: str | list[str]) -> list |
|
Suggests topics for the given document or list of documents. It runs each document |
|
through all classifiers, averages their scores, and returns a list of dictionaries where each |
|
dictionary contains the mean and standard deviation of the topic scores per document. |
|
|
|
Parameters: |
|
----------- |
|
input_docs : str or list of str |
|
A single document or a list of documents for which to suggest topics. |
|
|
|
Returns: |
|
-------- |
|
list |
|
A list of dictionaries, where each dictionary represents the suggested topics for |
|
each document, along with the mean and standard deviation of the topic classification scores. |
|
""" |
|
|
|
def __init__(self, classifiers: dict = None, device: str = None |
|
self.classifiers = classifiers or {} |
|
self.device = device |
|
|
|
if classifiers is None: |
|
self.load_classifiers() |
|
|
|
def load_classifiers(self): |
|
num_evals = 5 |
|
num_train = 5 |
|
|
|
tokenizer = AutoTokenizer.from_pretrained("avsolatorio/doc-topic-model_eval-04_train-03") |
|
|
|
for i in tqdm(range(num_evals)): |
|
for j in tqdm(range(num_train)): |
|
if i == j: |
|
continue |
|
|
|
model_name = f"avsolatorio/doc-topic-model_eval-{i:02}_train-{j:02}" |
|
classifier = pipeline("text-classification", model=model_name, tokenizer=tokenizer, top_k=None, device=self.device) |
|
|
|
self.classifiers[model_name] = classifier |
|
|
|
def suggest_topics(self, input_docs: str | list[str]): |
|
if isinstance(input_docs, str): |
|
input_docs = [input_docs] |
|
|
|
doc_outs = {i: [] for i in range(len(input_docs))} |
|
topics = [] |
|
|
|
for _, classifier in self.classifiers.items(): |
|
for doc_idx, doc in enumerate(classifier(input_docs)): |
|
doc_outs[doc_idx].append(pd.DataFrame.from_records(doc, index="label")) |
|
|
|
for doc_idx, outs in doc_outs.items(): |
|
all_scores = pd.concat(outs, axis=1) |
|
mean_probs = all_scores.mean(axis=1).sort_values(ascending=False) |
|
std_probs = all_scores.std(axis=1).loc[mean_probs.index] |
|
output = pd.DataFrame({"score_mean": mean_probs, "score_std": std_probs}) |
|
|
|
output["doc_idx"] = doc_idx |
|
output.reset_index(inplace=True) |
|
|
|
topics.append(output.to_dict(orient="records")) |
|
|
|
return topics |
|
```
|
|
|
|
|
# Using the WBGDocTopic model
|
|
|
```Python
|
import nltk

# Download the nltk sentence-tokenizer data if not present
nltk.download('punkt_tab')
nltk.download('punkt')

from collections import Counter

# Load the sent_tokenize method for quick sentence extraction
from nltk import sent_tokenize

# Process the input: split the abstract into sentences so that each sentence
# is scored as its own document.
sample_text = """A growing literature attributes gender inequality in labor market outcomes in part to the reduction in female labor supply after childbirth, the child penalty. However, if social norms constrain married women’s activities outside the home, then marriage can independently reduce employment, even in the absence childbearing. Given the correlation in timing between childbirth and marriage, conventional estimates of child penalties will conflate these two effects. The paper studies the marriage penalty in South Asia, a context featuring conservative gender norms and low female labor force participation. The study introduces a split-sample, pseudo-panel approach that allows for the separation of marriage and child penalties even in the absence of individual-level panel data. Marriage reduces women’s labor force participation in South Asia by 12 percentage points, whereas the marginal penalty of childbearing is small. Consistent with the central roles of both opportunity costs and social norms, the marriage penalty is smaller among cohorts with higher education and less conservative gender attitudes."""
sents = sent_tokenize(sample_text)

# Create the instance which will load the models.
# Set the device to "cuda" if you want to use a GPU.
dtopic_model = WBGDocTopic(device=None)

# Infer the topics and scores (one list of label records per sentence)
outs = dtopic_model.suggest_topics(sents)
outs
# [[{'label': 'Gender',
#    'score_mean': 0.8776359841227531,
#    'score_std': 0.13074095501538094,
#    'doc_idx': 0},
#   {'label': 'Labor Markets',
#    'score_mean': 0.20742715448141097,
#    'score_std': 0.20991565414467345,
#    'doc_idx': 0},
#   {'label': "Girls' Education",
#    'score_mean': 0.19432228063233198,
#    'score_std': 0.21148874269682794,
#    'doc_idx': 0}, ...]]

# Get the distribution of the abstract's highly relevant topics per sentence.
# Use a currently arbitrary threshold of 0.1, and keep a label only when its
# mean score also exceeds its standard deviation across the models.
Counter([o["label"] for out in outs for o in out if (o["score_mean"] > 0.1 and o["score_mean"] > o["score_std"])]).most_common()
|
```