Spaces:
Runtime error
Runtime error
import abc | |
from typing import Any | |
import logging | |
import re | |
import httpx | |
from base import JobInput | |
# Module-level logger named after this module. Its own level is set to DEBUG,
# so records at DEBUG severity and above pass the logger-level check.
logger = logging.getLogger(name=__name__)
logger.setLevel(level=logging.DEBUG)
class Processor(abc.ABC):
    """Base class for objects that turn a ``JobInput`` into a text string.

    Subclasses are expected to override :meth:`match` (does this processor
    apply to the given input?) and :meth:`process` (produce the text).
    Calling an instance runs :meth:`process` wrapped in start/finish logging.
    """

    def get_name(self) -> str:
        """Return the name of the concrete processor class."""
        return self.__class__.__name__

    def __call__(self, job: JobInput) -> str:
        """Process *job* and return the resulting text.

        Logs one INFO record before and one after delegating to
        :meth:`process`. Only the first eight characters of ``job.id`` are
        logged.

        Args:
            job: The job to process; its ``id`` attribute is read here.

        Returns:
            Whatever :meth:`process` returns for *job*.
        """
        # str() keeps the slice working when the id is not a plain string.
        short_id = str(job.id)[:8]
        # Log the job itself; the previous message interpolated the builtin
        # ``input`` function instead of the job being processed.
        logger.info(
            "Processing %s with %s (id=%s)",
            job,
            self.__class__.__name__,
            short_id,
        )
        result = self.process(job)
        logger.info("Finished processing input (id=%s)", short_id)
        return result

    def process(self, input: JobInput) -> str:
        """Convert *input* into text. Must be overridden by subclasses.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError

    def match(self, input: JobInput) -> bool:
        """Report whether this processor applies to *input*.

        Must be overridden by subclasses.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError
class Summarizer(abc.ABC):
    """Interface for callables that map a text to a summary string.

    Every method here raises ``NotImplementedError``; concrete summarizers
    override all three and must not call this ``__init__``.
    """

    def __init__(
        self,
        model_name: str,
        model: Any,
        tokenizer: Any,
        generation_config: Any,
    ) -> None:
        """Declare the constructor signature shared by summarizers."""
        raise NotImplementedError()

    def get_name(self) -> str:
        """Return a display name for the summarizer."""
        raise NotImplementedError()

    def __call__(self, x: str) -> str:
        """Return a summary of the text *x*."""
        raise NotImplementedError()
class Tagger(abc.ABC):
    """Interface for callables that map a text to a list of tag strings.

    Every method here raises ``NotImplementedError``; concrete taggers
    override all three and must not call this ``__init__``.
    """

    def __init__(
        self,
        model_name: str,
        model: Any,
        tokenizer: Any,
        generation_config: Any,
    ) -> None:
        """Declare the constructor signature shared by taggers."""
        raise NotImplementedError()

    def get_name(self) -> str:
        """Return a display name for the tagger."""
        raise NotImplementedError()

    def __call__(self, x: str) -> list[str]:
        """Return the tags produced for the text *x*."""
        raise NotImplementedError()
class MlRegistry:
    """Holds the registered processors plus one summarizer and one tagger.

    Processors are kept in registration order and are consulted in that
    order by :meth:`get_processor`. At most one summarizer and one tagger
    are stored; registering another replaces the previous one.
    """

    def __init__(self) -> None:
        """Create an empty registry."""
        self.processors: list[Processor] = []
        # NOTE: attribute name is misspelled ("summerizer"); kept as-is because
        # it is a public attribute that other code may read.
        self.summerizer: Summarizer | None = None
        self.tagger: Tagger | None = None
        # Not used by any method of this class; kept for external users.
        self.model = None
        self.tokenizer = None

    def register_processor(self, processor: Processor) -> None:
        """Append *processor* to the end of the lookup order."""
        self.processors.append(processor)

    def register_summarizer(self, summarizer: Summarizer) -> None:
        """Set *summarizer* as the registry's summarizer."""
        self.summerizer = summarizer

    def register_tagger(self, tagger: Tagger) -> None:
        """Set *tagger* as the registry's tagger."""
        self.tagger = tagger

    def get_processor(self, input: JobInput) -> Processor:
        """Return the first registered processor whose ``match`` accepts *input*.

        Falls back to a fresh ``RawTextProcessor`` when none matches.

        Raises:
            AssertionError: If no processor has been registered.
        """
        # Raised explicitly (rather than via ``assert``) so the check still
        # runs under ``python -O``; the exception type is unchanged.
        if not self.processors:
            raise AssertionError("no processor has been registered")
        for processor in self.processors:
            if processor.match(input):
                return processor
        return RawTextProcessor()

    def get_summarizer(self) -> Summarizer:
        """Return the registered summarizer.

        Raises:
            AssertionError: If no summarizer has been registered.
        """
        if self.summerizer is None:
            raise AssertionError("no summarizer has been registered")
        return self.summerizer

    def get_tagger(self) -> Tagger:
        """Return the registered tagger.

        Raises:
            AssertionError: If no tagger has been registered.
        """
        if self.tagger is None:
            raise AssertionError("no tagger has been registered")
        return self.tagger
class HfTransformersSummarizer(Summarizer):
    """Summarizer that prompts a generative model through its tokenizer.

    The input text is placed into a fixed prompt template, tokenized,
    passed to ``model.generate`` and the first decoded sequence is returned.
    """

    def __init__(
        self,
        model_name: str,
        model: Any,
        tokenizer: Any,
        generation_config: Any,
    ) -> None:
        """Store the model, tokenizer, generation config and prompt template."""
        self.model_name = model_name
        self.model = model
        self.tokenizer = tokenizer
        self.generation_config = generation_config
        self.template = "Summarize the text below in two sentences:\n\n{}"

    def __call__(self, x: str) -> str:
        """Return the model's summary of *x*.

        Args:
            x: Text to summarize; inserted into ``self.template``.

        Returns:
            The first decoded output sequence, special tokens skipped.
        """
        prompt = self.template.format(x)
        encoded = self.tokenizer(prompt, return_tensors="pt")
        generated = self.model.generate(
            **encoded,
            generation_config=self.generation_config,
        )
        decoded = self.tokenizer.batch_decode(generated, skip_special_tokens=True)
        summary = decoded[0]
        # Narrow the type for checkers; the tokenizer is typed as Any.
        assert isinstance(summary, str)
        return summary

    def get_name(self) -> str:
        """Return ``"<ClassName>(<model_name>)"``."""
        class_name = self.__class__.__name__
        return f"{class_name}({self.model_name})"
class HfTransformersTagger(Tagger):
    """Tagger that prompts a generative model and parses hashtags from its output.

    The input text is placed into a fixed prompt template, tokenized,
    passed to ``model.generate`` and the first decoded sequence is scanned
    for ``#``-prefixed words.
    """

    # Punctuation that commonly trails a tag in generated lists
    # ("#ai, #ml.") and is not part of the tag itself.
    _TRAILING_PUNCTUATION = ",.;:!?"

    def __init__(
        self,
        model_name: str,
        model: Any,
        tokenizer: Any,
        generation_config: Any,
    ) -> None:
        """Store the model, tokenizer, generation config and prompt template."""
        self.model_name = model_name
        self.model = model
        self.tokenizer = tokenizer
        self.generation_config = generation_config
        self.template = (
            "Create a list of tags for the text below. The tags should be high level "
            "and specific. Prefix each tag with a hashtag.\n\n{}\n\nTags: #general"
        )

    def _extract_tags(self, text: str) -> list[str]:
        """Return the distinct lower-cased hashtags in *text*, sorted.

        A tag is a whitespace-delimited token starting with ``#``. Trailing
        punctuation is stripped so ``"#ai,"`` and ``"#ai"`` are the same tag,
        and a token with nothing after the ``#`` is ignored.
        """
        tags = set()
        for token in text.split():
            if not token.startswith("#"):
                continue
            tag = token.rstrip(self._TRAILING_PUNCTUATION).lower()
            # Skip a bare "#" (or "#," etc.) that carries no tag text.
            if len(tag) > 1:
                tags.add(tag)
        return sorted(tags)

    def __call__(self, x: str) -> list[str]:
        """Return the tags the model generates for *x*.

        Args:
            x: Text to tag; inserted into ``self.template``.

        Returns:
            Sorted, de-duplicated, lower-cased hashtags found in the first
            decoded output sequence.
        """
        text = self.template.format(x)
        inputs = self.tokenizer(text, return_tensors="pt")
        outputs = self.model.generate(**inputs, generation_config=self.generation_config)
        output = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
        return self._extract_tags(output)

    def get_name(self) -> str:
        """Return ``"<ClassName>(<model_name>)"``."""
        return f"{self.__class__.__name__}({self.model_name})"
class RawTextProcessor(Processor):
    """Catch-all processor that passes the input's content through unchanged."""

    def match(self, input: JobInput) -> bool:
        """Accept every input."""
        return True

    def process(self, input: JobInput) -> str:
        """Return ``input.content`` as-is."""
        content = input.content
        return content
class DefaultUrlProcessor(Processor):
    """Processor for inputs whose content contains exactly one http(s) URL.

    :meth:`process` downloads that URL with the instance's ``httpx`` client
    and returns the URL followed by the response text.
    """

    def __init__(self) -> None:
        """Create the HTTP client, the URL pattern and the output template."""
        self.client = httpx.Client()
        self.regex = re.compile(r"(https?://[^\s]+)")
        # URL found by the most recent successful match(). process() only
        # falls back to it when its own input does not contain a single URL.
        self.url = None
        self.template = "{url}\n\n{content}"

    def _single_url(self, text: str) -> "str | None":
        """Return the URL in *text* if there is exactly one, else ``None``."""
        urls = self.regex.findall(text)
        return urls[0] if len(urls) == 1 else None

    def match(self, input: JobInput) -> bool:
        """Report whether ``input.content`` contains exactly one URL.

        On success the URL is also stored in ``self.url``.
        """
        url = self._single_url(input.content)
        if url is None:
            return False
        self.url = url
        return True

    def process(self, input: JobInput) -> str:
        """Get content of website and return it as string.

        The URL is taken from *input* itself, so the result does not depend
        on which job was last passed to :meth:`match` on this shared
        instance. If *input* does not contain exactly one URL, the URL
        stored by the last successful :meth:`match` is used instead.

        Returns:
            ``self.template`` filled with the URL and the response text.

        Raises:
            AssertionError: If no URL is available from *input* or from a
                previous :meth:`match`.
        """
        url = self._single_url(input.content)
        if url is None:
            url = self.url
        if not isinstance(url, str):
            raise AssertionError("no URL available: input has none and match() found none")
        text = self.client.get(url).text
        # Narrow the type for checkers; also guards against a non-text body.
        assert isinstance(text, str)
        return self.template.format(url=url, content=text)
# class ProcessorRegistry: | |
# def __init__(self) -> None: | |
# self.registry: list[Processor] = [] | |
# self.default_registry: list[Processor] = [] | |
# self.set_default_processors() | |
# def set_default_processors(self) -> None: | |
# self.default_registry.extend([PlainUrlProcessor(), RawProcessor()]) | |
# def register(self, processor: Processor) -> None: | |
# self.registry.append(processor) | |
# def dispatch(self, input: JobInput) -> Processor: | |
# for processor in self.registry + self.default_registry: | |
# if processor.match(input): | |
# return processor | |
# # should never be required, but eh
# return RawProcessor() | |