File size: 6,056 Bytes
a240da9
126a4c6
a240da9
 
 
 
 
 
 
 
 
 
 
126a4c6
 
 
 
 
 
 
 
 
 
a240da9
126a4c6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a240da9
126a4c6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a240da9
 
 
 
126a4c6
 
 
a240da9
 
 
 
126a4c6
a240da9
 
126a4c6
 
 
 
 
 
 
a240da9
 
 
 
 
 
 
 
 
 
 
 
 
 
126a4c6
 
 
a240da9
 
 
 
126a4c6
a240da9
 
126a4c6
a240da9
 
 
 
 
 
 
126a4c6
a240da9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126a4c6
 
 
 
 
a240da9
126a4c6
 
a240da9
126a4c6
 
a240da9
126a4c6
 
 
 
a240da9
126a4c6
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
import abc
from typing import Any
import logging
import re

import httpx

from base import JobInput

# Module-level logger named after this module.
# Its level is set to DEBUG here, so DEBUG and higher records pass this
# logger's own level check regardless of what the application configures.
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)


class Processor(abc.ABC):
    """Base class for components that turn a ``JobInput`` into a string.

    Subclasses implement :meth:`match` (can this processor handle the input?)
    and :meth:`process` (produce the string). Calling the instance runs
    :meth:`process` wrapped in a start and a finish log message.
    """

    def get_name(self) -> str:
        """Return the name of the concrete class."""
        return self.__class__.__name__

    def __call__(self, job: JobInput) -> str:
        """Run :meth:`process` on ``job``, logging before and after.

        Args:
            job: The job to process. Only the first eight characters of
                ``job.id`` appear in the log messages.

        Returns:
            Whatever :meth:`process` returns for ``job``.
        """
        _id = job.id
        # Interpolate ``job`` here; the bare name ``input`` would resolve to
        # the builtin function and log "<built-in function input>".
        logger.info(f"Processing {job} with {self.__class__.__name__} (id={_id[:8]})")
        result = self.process(job)
        logger.info(f"Finished processing input (id={_id[:8]})")
        return result

    @abc.abstractmethod
    def process(self, input: JobInput) -> str:
        """Convert ``input`` into the string handed to later stages."""
        raise NotImplementedError

    @abc.abstractmethod
    def match(self, input: JobInput) -> bool:
        """Return ``True`` if this processor can handle ``input``."""
        raise NotImplementedError


class Summarizer(abc.ABC):
    """Interface for callables that map a text to a summary string.

    Only ``__call__`` is abstract. ``__init__`` and ``get_name`` are ordinary
    methods that raise ``NotImplementedError`` unless a subclass overrides
    them.
    """

    def __init__(self, model_name: str, model: Any, tokenizer: Any, generation_config: Any) -> None:
        """Placeholder constructor; raises ``NotImplementedError``."""
        raise NotImplementedError()

    @abc.abstractmethod
    def __call__(self, x: str) -> str:
        """Return the summary for the text ``x``."""
        raise NotImplementedError()

    def get_name(self) -> str:
        """Return a display name; raises ``NotImplementedError`` here."""
        raise NotImplementedError()


class Tagger(abc.ABC):
    """Interface for callables that map a text to a list of tag strings.

    Only ``__call__`` is abstract. ``__init__`` and ``get_name`` are ordinary
    methods that raise ``NotImplementedError`` unless a subclass overrides
    them.
    """

    def __init__(self, model_name: str, model: Any, tokenizer: Any, generation_config: Any) -> None:
        """Placeholder constructor; raises ``NotImplementedError``."""
        raise NotImplementedError()

    @abc.abstractmethod
    def __call__(self, x: str) -> list[str]:
        """Return the tags for the text ``x``."""
        raise NotImplementedError()

    def get_name(self) -> str:
        """Return a display name; raises ``NotImplementedError`` here."""
        raise NotImplementedError()


class MlRegistry:
    """Registry holding a list of processors plus one summarizer and one tagger.

    Processors are kept in registration order and consulted in that order.
    The summarizer and tagger are single slots; registering again replaces
    the previous value.
    """

    def __init__(self) -> None:
        """Create a registry with nothing registered."""
        self.processors: list[Processor] = []
        # NOTE: the attribute name is misspelled ("summerizer"). It is kept
        # as-is because code outside this class may read it.
        self.summerizer: Summarizer | None = None
        self.tagger: Tagger | None = None
        # Not read or written by any method of this class.
        self.model = None
        self.tokenizer = None

    def register_processor(self, processor: Processor) -> None:
        """Append ``processor`` to the end of the lookup order."""
        self.processors.append(processor)

    def register_summarizer(self, summarizer: Summarizer) -> None:
        """Set ``summarizer`` as the registry's summarizer, replacing any previous one."""
        self.summerizer = summarizer

    def register_tagger(self, tagger: Tagger) -> None:
        """Set ``tagger`` as the registry's tagger, replacing any previous one."""
        self.tagger = tagger

    def get_processor(self, input: JobInput) -> Processor:
        """Return the first registered processor whose ``match`` accepts ``input``.

        If no registered processor matches, a new ``RawTextProcessor`` is
        returned as the fallback.

        Raises:
            AssertionError: If no processor has been registered at all.
        """
        if not self.processors:
            # Raised explicitly rather than via ``assert`` so the check still
            # runs under ``python -O``; the exception type is unchanged.
            raise AssertionError("no processor has been registered")
        for processor in self.processors:
            if processor.match(input):
                return processor

        return RawTextProcessor()

    def get_summarizer(self) -> Summarizer:
        """Return the registered summarizer.

        Raises:
            AssertionError: If no summarizer has been registered.
        """
        # Compare against ``None`` so a registered object is never rejected
        # merely because it evaluates as falsy.
        if self.summerizer is None:
            raise AssertionError("no summarizer has been registered")
        return self.summerizer

    def get_tagger(self) -> Tagger:
        """Return the registered tagger.

        Raises:
            AssertionError: If no tagger has been registered.
        """
        if self.tagger is None:
            raise AssertionError("no tagger has been registered")
        return self.tagger


class HfTransformersSummarizer(Summarizer):
    """Summarizer that prompts a generative model through its tokenizer.

    The model must expose ``generate`` and the tokenizer must be callable and
    expose ``batch_decode``; presumably these are Hugging Face transformers
    objects (confirm against the caller).
    """

    def __init__(self, model_name: str, model: Any, tokenizer: Any, generation_config: Any) -> None:
        """Store the model pieces and the prompt template.

        Args:
            model_name: Label used only by :meth:`get_name`.
            model: Object providing ``generate``.
            tokenizer: Callable tokenizer providing ``batch_decode``.
            generation_config: Passed through to ``model.generate``.
        """
        self.template = "Summarize the text below in two sentences:\n\n{}"

        self.generation_config = generation_config
        self.tokenizer = tokenizer
        self.model = model
        self.model_name = model_name

    def get_name(self) -> str:
        """Return ``ClassName(model_name)``."""
        return f"{self.__class__.__name__}({self.model_name})"

    def __call__(self, x: str) -> str:
        """Return the model's decoded output for the summarization prompt built from ``x``."""
        prompt = self.template.format(x)
        encoded = self.tokenizer(prompt, return_tensors="pt")
        generated = self.model.generate(**encoded, generation_config=self.generation_config)
        decoded = self.tokenizer.batch_decode(generated, skip_special_tokens=True)
        # Only the first decoded sequence is used.
        summary = decoded[0]
        assert isinstance(summary, str)
        return summary


class HfTransformersTagger(Tagger):
    """Tagger that prompts a generative model and collects hashtag tokens from its output.

    The model must expose ``generate`` and the tokenizer must be callable and
    expose ``batch_decode``; presumably these are Hugging Face transformers
    objects (confirm against the caller).
    """

    # Punctuation that often trails a tag in generated lists ("#ai, #ml.")
    # and is not part of the tag itself.
    _TRAILING_PUNCTUATION = ".,;:!?"

    def __init__(self, model_name: str, model: Any, tokenizer: Any, generation_config: Any) -> None:
        """Store the model pieces and the prompt template.

        Args:
            model_name: Label used only by :meth:`get_name`.
            model: Object providing ``generate``.
            tokenizer: Callable tokenizer providing ``batch_decode``.
            generation_config: Passed through to ``model.generate``.
        """
        self.model_name = model_name
        self.model = model
        self.tokenizer = tokenizer
        self.generation_config = generation_config

        # The prompt ends with a first tag ("#general") after "Tags:".
        self.template = (
            "Create a list of tags for the text below. The tags should be high level "
            "and specific. Prefix each tag with a hashtag.\n\n{}\n\nTags: #general"
        )

    def _extract_tags(self, text: str) -> list[str]:
        """Return the sorted, de-duplicated, lower-cased hashtags found in ``text``.

        A tag is a whitespace-delimited token starting with ``#``. Trailing
        punctuation is removed so ``"#ai,"`` and ``"#ai"`` count as one tag,
        and tokens made only of ``#`` characters (for example a bare ``#`` or
        ``##``) are ignored.
        """
        tags = set()
        for token in text.split():
            if not token.startswith("#"):
                continue
            tag = token.rstrip(self._TRAILING_PUNCTUATION).lower()
            # Require at least one character besides the leading hash marks.
            if tag.lstrip("#"):
                tags.add(tag)
        return sorted(tags)

    def __call__(self, x: str) -> list[str]:
        """Return the tags extracted from the model's output for ``x``."""
        text = self.template.format(x)
        inputs = self.tokenizer(text, return_tensors="pt")
        outputs = self.model.generate(**inputs, generation_config=self.generation_config)
        # Only the first decoded sequence is used.
        output = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
        return self._extract_tags(output)

    def get_name(self) -> str:
        """Return ``ClassName(model_name)``."""
        return f"{self.__class__.__name__}({self.model_name})"


class RawTextProcessor(Processor):
    """Processor that accepts every input and passes its content through."""

    def process(self, input: JobInput) -> str:
        """Return ``input.content`` unchanged."""
        raw_text = input.content
        return raw_text

    def match(self, input: JobInput) -> bool:
        """Accept any input."""
        return True


class DefaultUrlProcessor(Processor):
    """Processor for inputs whose content contains exactly one http(s) URL.

    :meth:`process` downloads that URL with a shared ``httpx.Client`` and
    returns the URL followed by the response text.
    """

    def __init__(self) -> None:
        """Set up the HTTP client, the URL pattern and the output template."""
        self.client = httpx.Client()
        self.regex = re.compile(r"(https?://[^\s]+)")
        # Last URL accepted by ``match``. Still assigned for code that reads
        # the attribute; ``process`` does not rely on it.
        self.url = None
        self.template = "{url}\n\n{content}"

    def _find_urls(self, content: str) -> list[str]:
        """Return every http(s) URL found in ``content``, in order of appearance."""
        return self.regex.findall(content)

    def match(self, input: JobInput) -> bool:
        """Return ``True`` when ``input.content`` contains exactly one URL."""
        urls = self._find_urls(input.content)
        if len(urls) != 1:
            return False
        self.url = urls[0]
        return True

    def process(self, input: JobInput) -> str:
        """Fetch the single URL in ``input.content`` and return URL plus page text.

        The URL is taken from ``input`` itself rather than from state left
        behind by an earlier ``match`` call, so a shared instance cannot
        fetch a URL belonging to a different input.

        Raises:
            ValueError: If ``input.content`` does not contain exactly one URL.
        """
        urls = self._find_urls(input.content)
        if len(urls) != 1:
            raise ValueError(f"expected exactly one URL in input content, found {len(urls)}")
        url = urls[0]
        # NOTE(review): the URL comes straight from the job content. If that
        # content is untrusted, this request can reach internal hosts (SSRF);
        # confirm whether an allow-list is needed.
        text = self.client.get(url).text
        return self.template.format(url=url, content=text)

# class ProcessorRegistry:
#     def __init__(self) -> None:
#         self.registry: list[Processor] = []
#         self.default_registry: list[Processor] = []
#         self.set_default_processors()

#     def set_default_processors(self) -> None:
#         self.default_registry.extend([PlainUrlProcessor(), RawProcessor()])

#     def register(self, processor: Processor) -> None:
#         self.registry.append(processor)

#     def dispatch(self, input: JobInput) -> Processor:
#         for processor in self.registry + self.default_registry:
#             if processor.match(input):
#                 return processor

#         # should never be required, but eh
#         return RawProcessor()