File size: 3,724 Bytes
96070b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
from transformers import pipeline
from tqdm.auto import tqdm
import pandas as pd
from transformers import AutoTokenizer
import nltk

# Make sure the NLTK tokenizer resources used by this module are available.
# This runs at import time; the resources are requested in the order listed.
for _nltk_resource in ('punkt_tab', 'punkt'):
    nltk.download(_nltk_resource)


class WBGDocTopic:
    """
    A class to handle document topic suggestion using multiple pre-trained text classification models.

    This class loads a set of text classification models from Hugging Face's model hub and
    provides a method to suggest topics for input documents based on the aggregated classification
    results from all the models.

    Attributes:
    -----------
    classifiers : dict
        A dictionary mapping model names to corresponding classification pipelines. It holds
        instances of Hugging Face's `pipeline` used for text classification.

    Methods:
    --------
    __init__(classifiers: dict = None)
        Initializes the `WBGDocTopic` instance. If no classifiers are provided, it loads a default
        set of classifiers by calling `load_classifiers`.

    load_classifiers()
        Loads a predefined set of document topic classifiers into the `classifiers` dictionary.
        It uses `tqdm` to display progress as the classifiers are loaded.

    suggest_topics(input_docs: str | list[str]) -> list
        Suggests topics for the given document or list of documents. It runs each document
        through all classifiers, averages their scores, and returns a list of dictionaries where each
        dictionary contains the mean and standard deviation of the topic scores per document.

        Parameters:
        -----------
        input_docs : str or list of str
            A single document or a list of documents for which to suggest topics.

        Returns:
        --------
        list
            A list of dictionaries, where each dictionary represents the suggested topics for
            each document, along with the mean and standard deviation of the topic classification scores.
    """

    def __init__(self, classifiers: dict = None, device: str = None):
        self.classifiers = classifiers or {}
        self.device = device

        if classifiers is None:
            self.load_classifiers()

    def load_classifiers(self):
        num_evals = 5
        num_train = 5

        tokenizer = AutoTokenizer.from_pretrained("avsolatorio/doc-topic-model_eval-04_train-03")

        for i in tqdm(range(num_evals)):
            for j in tqdm(range(num_train)):
                if i == j:
                    continue

                model_name = f"avsolatorio/doc-topic-model_eval-{i:02}_train-{j:02}"
                classifier = pipeline("text-classification", model=model_name, tokenizer=tokenizer, top_k=None, device=self.device)

                self.classifiers[model_name] = classifier

    def suggest_topics(self, input_docs: str | list[str]):
        if isinstance(input_docs, str):
            input_docs = [input_docs]

        doc_outs = {i: [] for i in range(len(input_docs))}
        topics = []

        for _, classifier in self.classifiers.items():
            for doc_idx, doc in enumerate(classifier(input_docs)):
                doc_outs[doc_idx].append(pd.DataFrame.from_records(doc, index="label"))

        for doc_idx, outs in doc_outs.items():
            all_scores = pd.concat(outs, axis=1)
            mean_probs = all_scores.mean(axis=1).sort_values(ascending=False)
            std_probs = all_scores.std(axis=1).loc[mean_probs.index]
            output = pd.DataFrame({"score_mean": mean_probs, "score_std": std_probs})

            output["doc_idx"] = doc_idx
            output.reset_index(inplace=True)

            topics.append(output.to_dict(orient="records"))

        return topics