Spaces:
Runtime error
Runtime error
File size: 6,056 Bytes
a240da9 126a4c6 a240da9 126a4c6 a240da9 126a4c6 a240da9 126a4c6 a240da9 126a4c6 a240da9 126a4c6 a240da9 126a4c6 a240da9 126a4c6 a240da9 126a4c6 a240da9 126a4c6 a240da9 126a4c6 a240da9 126a4c6 a240da9 126a4c6 a240da9 126a4c6 a240da9 126a4c6 a240da9 126a4c6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 |
import abc
from typing import Any
import logging
import re
import httpx
from base import JobInput
# Module-level logger; DEBUG level so processor start/finish messages are visible.
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
class Processor(abc.ABC):
    """Base class for job processors.

    Subclasses implement ``match`` (can this processor handle the input?)
    and ``process`` (do the work).  Calling the instance runs ``process``
    wrapped with start/finish logging.
    """

    def get_name(self) -> str:
        """Return a human-readable name for this processor."""
        return self.__class__.__name__

    def __call__(self, job: JobInput) -> str:
        """Process *job* and return the result, logging a shortened job id."""
        _id = job.id
        # Bug fix: the original f-string interpolated `{input}`, which is the
        # *builtin* input function, not the job -- it logged
        # "<built-in function input>".  Log the job id instead, with lazy
        # %-style args so formatting only happens if the record is emitted.
        logger.info("Processing job (id=%s) with %s", _id[:8], self.__class__.__name__)
        result = self.process(job)
        logger.info("Finished processing input (id=%s)", _id[:8])
        return result

    @abc.abstractmethod
    def process(self, input: JobInput) -> str:
        """Transform the job input into a string result."""
        raise NotImplementedError

    @abc.abstractmethod
    def match(self, input: JobInput) -> bool:
        """Return True if this processor can handle the given input."""
        raise NotImplementedError
class Summarizer(abc.ABC):
    """Interface for text summarizers: callable mapping text -> summary."""

    # Bug fix: __init__ and get_name raised NotImplementedError but were not
    # marked abstract, so an incomplete subclass could be instantiated and only
    # fail later at call time.  Marking them abstract makes ABC reject such
    # subclasses at construction, which is the point of using abc.ABC.
    @abc.abstractmethod
    def __init__(self, model_name: str, model: Any, tokenizer: Any, generation_config: Any) -> None:
        raise NotImplementedError

    @abc.abstractmethod
    def get_name(self) -> str:
        """Return a human-readable name for this summarizer."""
        raise NotImplementedError

    @abc.abstractmethod
    def __call__(self, x: str) -> str:
        """Summarize *x* and return the summary text."""
        raise NotImplementedError
class Tagger(abc.ABC):
    """Interface for taggers: callable mapping text -> list of tags."""

    # Bug fix: __init__ and get_name raised NotImplementedError but were not
    # marked abstract, so an incomplete subclass could be instantiated and only
    # fail later at call time.  Marking them abstract makes ABC reject such
    # subclasses at construction.
    @abc.abstractmethod
    def __init__(self, model_name: str, model: Any, tokenizer: Any, generation_config: Any) -> None:
        raise NotImplementedError

    @abc.abstractmethod
    def get_name(self) -> str:
        """Return a human-readable name for this tagger."""
        raise NotImplementedError

    @abc.abstractmethod
    def __call__(self, x: str) -> list[str]:
        """Tag *x* and return the list of tags."""
        raise NotImplementedError
class MlRegistry:
    """Registry wiring together processors, a summarizer, and a tagger.

    Processors are tried in registration order; the first whose ``match``
    returns True wins, with ``RawTextProcessor`` as the catch-all fallback.
    """

    def __init__(self) -> None:
        self.processors: list[Processor] = []
        # NOTE(review): "summerizer" is a misspelling, but it is a public
        # attribute -- kept as-is for backward compatibility with callers.
        self.summerizer: Summarizer | None = None
        self.tagger: Tagger | None = None
        self.model = None
        self.tokenizer = None

    def register_processor(self, processor: Processor) -> None:
        """Append *processor* to the match order."""
        self.processors.append(processor)

    def register_summarizer(self, summarizer: Summarizer) -> None:
        self.summerizer = summarizer

    def register_tagger(self, tagger: Tagger) -> None:
        self.tagger = tagger

    def get_processor(self, input: JobInput) -> Processor:
        """Return the first matching processor, or a RawTextProcessor fallback.

        Bug fix: the previous ``assert self.processors`` made an *empty*
        registry crash, even though the fallback below handles that case
        perfectly well (and the assert vanished under ``python -O`` anyway).
        """
        for processor in self.processors:
            if processor.match(input):
                return processor
        return RawTextProcessor()

    def get_summarizer(self) -> Summarizer:
        assert self.summerizer
        return self.summerizer

    def get_tagger(self) -> Tagger:
        assert self.tagger
        return self.tagger
class HfTransformersSummarizer(Summarizer):
    """Summarizer backed by a Hugging Face transformers model.

    Formats the input into a summarization prompt, runs the model's
    ``generate``, and decodes the first returned sequence.
    """

    def __init__(self, model_name: str, model: Any, tokenizer: Any, generation_config: Any) -> None:
        self.model_name = model_name
        self.model = model
        self.tokenizer = tokenizer
        self.generation_config = generation_config
        self.template = "Summarize the text below in two sentences:\n\n{}"

    def __call__(self, x: str) -> str:
        """Return a two-sentence summary of *x*."""
        prompt = self.template.format(x)
        encoded = self.tokenizer(prompt, return_tensors="pt")
        generated = self.model.generate(**encoded, generation_config=self.generation_config)
        decoded = self.tokenizer.batch_decode(generated, skip_special_tokens=True)
        summary = decoded[0]
        assert isinstance(summary, str)
        return summary

    def get_name(self) -> str:
        """Return class name qualified with the underlying model name."""
        return f"{self.__class__.__name__}({self.model_name})"
class HfTransformersTagger(Tagger):
    """Tagger that prompts a Hugging Face transformers model for hashtags.

    The prompt seeds the model with ``Tags: #general``; hashtag tokens are
    then extracted from the generated text.
    """

    def __init__(self, model_name: str, model: Any, tokenizer: Any, generation_config: Any) -> None:
        self.model_name = model_name
        self.model = model
        self.tokenizer = tokenizer
        self.generation_config = generation_config
        self.template = (
            "Create a list of tags for the text below. The tags should be high level "
            "and specific. Prefix each tag with a hashtag.\n\n{}\n\nTags: #general"
        )

    def _extract_tags(self, text: str) -> list[str]:
        # Unique hashtag tokens, lower-cased, returned in sorted order.
        found = {token.lower() for token in text.split() if token.startswith("#")}
        return sorted(found)

    def __call__(self, x: str) -> list[str]:
        """Generate and return a sorted list of hashtag tags for *x*."""
        prompt = self.template.format(x)
        encoded = self.tokenizer(prompt, return_tensors="pt")
        generated = self.model.generate(**encoded, generation_config=self.generation_config)
        decoded = self.tokenizer.batch_decode(generated, skip_special_tokens=True)[0]
        return self._extract_tags(decoded)

    def get_name(self) -> str:
        """Return class name qualified with the underlying model name."""
        return f"{self.__class__.__name__}({self.model_name})"
class RawTextProcessor(Processor):
    """Catch-all processor: accepts every job and returns its content verbatim."""

    def match(self, input: JobInput) -> bool:
        # Always applicable -- this is the registry's fallback.
        return True

    def process(self, input: JobInput) -> str:
        """Pass the job content through unchanged."""
        return input.content
class DefaultUrlProcessor(Processor):
    """Processor for inputs consisting of exactly one http(s) URL.

    ``match`` detects (and remembers) the URL; ``process`` fetches the page
    and returns it prefixed with the URL.  NOTE(review): the URL is carried
    from match() to process() as instance state, so a single instance is not
    safe for concurrent use -- confirm callers use it sequentially.
    """

    def __init__(self) -> None:
        self.client = httpx.Client()
        self.regex = re.compile(r"(https?://[^\s]+)")
        self.url: str | None = None
        self.template = "{url}\n\n{content}"

    def match(self, input: JobInput) -> bool:
        """Return True iff the content contains exactly one http(s) URL."""
        urls = list(self.regex.findall(input.content))
        if len(urls) == 1:
            # Remember the URL for the subsequent process() call.
            self.url = urls[0]
            return True
        # Bug fix: clear any URL left over from an earlier successful match,
        # so a stale value cannot leak into a later process() call.
        self.url = None
        return False

    def process(self, input: JobInput) -> str:
        """Fetch the matched URL and return the URL plus the page text."""
        assert isinstance(self.url, str)
        text = self.client.get(self.url).text
        assert isinstance(text, str)
        return self.template.format(url=self.url, content=text)
# class ProcessorRegistry:
# def __init__(self) -> None:
# self.registry: list[Processor] = []
# self.default_registry: list[Processor] = []
# self.set_default_processors()
# def set_default_processors(self) -> None:
# self.default_registry.extend([PlainUrlProcessor(), RawProcessor()])
# def register(self, processor: Processor) -> None:
# self.registry.append(processor)
# def dispatch(self, input: JobInput) -> Processor:
# for processor in self.registry + self.default_registry:
# if processor.match(input):
# return processor
# # should never be required, but eh
# return RawProcessor()
|