gistillery / src /ml.py
Benjamin Bossan
Refactor ml model handling
126a4c6
raw
history blame
6.06 kB
import abc
from typing import Any
import logging
import re
import httpx
from base import JobInput
# Module-level logger named after this module. Its own level is set to DEBUG,
# so records of every severity are created; handlers (on this logger or its
# ancestors) still decide what is actually emitted.
logger: logging.Logger = logging.getLogger(name=__name__)
logger.setLevel(level=logging.DEBUG)
class Processor(abc.ABC):
    """Base class for steps that turn a ``JobInput`` into text.

    Subclasses implement ``match`` (whether this processor applies to a job)
    and ``process`` (produce the text). Calling the instance runs ``process``
    wrapped in start/finish log messages.
    """

    def get_name(self) -> str:
        """Return the name identifying this processor (the class name by default)."""
        return self.__class__.__name__

    def __call__(self, job: JobInput) -> str:
        """Run ``process`` on ``job`` and return its result, logging before and after."""
        _id = job.id
        # Only the short id is logged, not the (possibly large) job content.
        # Lazy %-style args avoid formatting when INFO is disabled.
        logger.info("Processing input with %s (id=%s)", self.get_name(), _id[:8])
        result = self.process(job)
        logger.info("Finished processing input (id=%s)", _id[:8])
        return result

    @abc.abstractmethod
    def process(self, input: JobInput) -> str:
        """Produce the text for ``input``."""
        raise NotImplementedError

    @abc.abstractmethod
    def match(self, input: JobInput) -> bool:
        """Return True if this processor should handle ``input``."""
        raise NotImplementedError
class Summarizer(abc.ABC):
    """Interface for text summarizers: called with a string, returns a string.

    Only ``__call__`` is abstract, but the base ``__init__`` and ``get_name``
    both raise ``NotImplementedError``. Subclasses must override all three to
    be usable.
    """

    @abc.abstractmethod
    def __call__(self, x: str) -> str:
        """Return a summary of ``x``."""
        raise NotImplementedError()

    def __init__(
        self,
        model_name: str,
        model: Any,
        tokenizer: Any,
        generation_config: Any,
    ) -> None:
        # Deliberately unusable: concrete subclasses provide their own constructor.
        raise NotImplementedError()

    def get_name(self) -> str:
        """Return a human-readable identifier for this summarizer."""
        raise NotImplementedError()
class Tagger(abc.ABC):
    """Interface for taggers: called with a string, returns a list of tag strings.

    Only ``__call__`` is abstract, but the base ``__init__`` and ``get_name``
    both raise ``NotImplementedError``. Subclasses must override all three to
    be usable.
    """

    @abc.abstractmethod
    def __call__(self, x: str) -> list[str]:
        """Return the tags for ``x``."""
        raise NotImplementedError()

    def __init__(
        self,
        model_name: str,
        model: Any,
        tokenizer: Any,
        generation_config: Any,
    ) -> None:
        # Deliberately unusable: concrete subclasses provide their own constructor.
        raise NotImplementedError()

    def get_name(self) -> str:
        """Return a human-readable identifier for this tagger."""
        raise NotImplementedError()
class MlRegistry:
    """Registry of the processors, summarizer and tagger used to handle jobs.

    Processors are consulted in registration order. The summarizer and the
    tagger are single slots: registering a new one replaces the previous one.
    """

    def __init__(self) -> None:
        self.processors: list[Processor] = []
        self.summarizer: Summarizer | None = None
        self.tagger: Tagger | None = None
        # Not read by any method of this class; kept in case code elsewhere
        # sets or reads them.
        self.model = None
        self.tokenizer = None

    @property
    def summerizer(self) -> "Summarizer | None":
        """Backwards-compatible alias for ``summarizer`` (the old, misspelled name)."""
        return self.summarizer

    @summerizer.setter
    def summerizer(self, value: "Summarizer | None") -> None:
        self.summarizer = value

    def register_processor(self, processor: Processor) -> None:
        """Append ``processor``; processors registered earlier take precedence."""
        self.processors.append(processor)

    def register_summarizer(self, summarizer: Summarizer) -> None:
        """Set the summarizer, replacing any previously registered one."""
        self.summarizer = summarizer

    def register_tagger(self, tagger: Tagger) -> None:
        """Set the tagger, replacing any previously registered one."""
        self.tagger = tagger

    def get_processor(self, input: JobInput) -> Processor:
        """Return the first registered processor whose ``match`` accepts ``input``.

        The search stops at the first match. If none matches, a new
        ``RawTextProcessor`` is returned.

        Raises:
            RuntimeError: if no processor has been registered.
        """
        # Explicit check rather than ``assert`` so it still fires under ``python -O``.
        if not self.processors:
            raise RuntimeError("No processors registered")
        for processor in self.processors:
            if processor.match(input):
                return processor
        return RawTextProcessor()

    def get_summarizer(self) -> Summarizer:
        """Return the registered summarizer.

        Raises:
            RuntimeError: if no summarizer has been registered.
        """
        if self.summarizer is None:
            raise RuntimeError("No summarizer registered")
        return self.summarizer

    def get_tagger(self) -> Tagger:
        """Return the registered tagger.

        Raises:
            RuntimeError: if no tagger has been registered.
        """
        if self.tagger is None:
            raise RuntimeError("No tagger registered")
        return self.tagger
class HfTransformersSummarizer(Summarizer):
    """Summarizer driven by a model/tokenizer pair (presumably Hugging Face transformers).

    The input is wrapped in a fixed instruction prompt and tokenized into
    PyTorch tensors (``return_tensors="pt"``). The result of
    ``model.generate`` is decoded, and the first decoded sequence is returned.
    """

    def __init__(self, model_name: str, model: Any, tokenizer: Any, generation_config: Any) -> None:
        # The base-class __init__ only raises, so it is intentionally not called.
        self.model_name = model_name
        self.model = model
        self.tokenizer = tokenizer
        self.generation_config = generation_config
        self.template = "Summarize the text below in two sentences:\n\n{}"

    def get_name(self) -> str:
        """Return the class name together with the model name."""
        return f"{type(self).__name__}({self.model_name})"

    def __call__(self, x: str) -> str:
        """Return the model-generated summary of ``x``."""
        prompt = self.template.format(x)
        encoded = self.tokenizer(prompt, return_tensors="pt")
        generated = self.model.generate(**encoded, generation_config=self.generation_config)
        decoded = self.tokenizer.batch_decode(generated, skip_special_tokens=True)
        summary = decoded[0]
        assert isinstance(summary, str)
        return summary
class HfTransformersTagger(Tagger):
    """Tagger that prompts a generative model for hashtags and parses them out.

    Presumably backed by Hugging Face transformers: inputs are tokenized into
    PyTorch tensors (``return_tensors="pt"``) and passed to ``model.generate``.
    """

    # Characters stripped from the end of a hashtag token, e.g. "#python," -> "#python".
    _TRAILING_PUNCTUATION = ",.;:!?"

    def __init__(self, model_name: str, model: Any, tokenizer: Any, generation_config: Any) -> None:
        # The base-class __init__ only raises, so it is intentionally not called.
        self.model_name = model_name
        self.model = model
        self.tokenizer = tokenizer
        self.generation_config = generation_config
        # The prompt ends with an example tag, presumably to prime the model
        # into answering with hashtags in the same format.
        self.template = (
            "Create a list of tags for the text below. The tags should be high level "
            "and specific. Prefix each tag with a hashtag.\n\n{}\n\nTags: #general"
        )

    def _extract_tags(self, text: str) -> list[str]:
        """Return the sorted, de-duplicated, lower-cased hashtags found in ``text``.

        Whitespace-separated tokens starting with ``#`` are treated as tags.
        Trailing punctuation is stripped so "#python," and "#python" collapse
        into one tag. A token that is nothing but ``#`` (after stripping) is
        ignored.
        """
        tags = set()
        for token in text.split():
            if not token.startswith("#"):
                continue
            tag = token.rstrip(self._TRAILING_PUNCTUATION).lower()
            if len(tag) > 1:  # skip a bare "#"
                tags.add(tag)
        return sorted(tags)

    def __call__(self, x: str) -> list[str]:
        """Generate tags for ``x`` and return them via ``_extract_tags``."""
        text = self.template.format(x)
        inputs = self.tokenizer(text, return_tensors="pt")
        outputs = self.model.generate(**inputs, generation_config=self.generation_config)
        output = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
        return self._extract_tags(output)

    def get_name(self) -> str:
        """Return the class name together with the model name."""
        return f"{self.__class__.__name__}({self.model_name})"
class RawTextProcessor(Processor):
    """Processor that passes the job's content through unchanged."""

    def process(self, input: JobInput) -> str:
        """Return ``input.content`` as-is."""
        content = input.content
        return content

    def match(self, input: JobInput) -> bool:
        # Accepts every job, so it can serve as a catch-all.
        return True
class DefaultUrlProcessor(Processor):
    """Processor for jobs whose content contains exactly one http(s) URL.

    ``process`` downloads the page and returns its text prefixed by the URL.
    """

    def __init__(self) -> None:
        # httpx does not follow redirects by default; without this, e.g. an
        # http -> https redirect would return the redirect response body.
        self.client = httpx.Client(follow_redirects=True)
        self.regex = re.compile(r"(https?://[^\s]+)")
        # URL found by the most recent successful ``match`` call (kept for
        # compatibility; ``process`` does not rely on it).
        self.url: str | None = None
        self.template = "{url}\n\n{content}"

    def _find_url(self, content: str) -> "str | None":
        """Return the URL in ``content`` if there is exactly one, else ``None``."""
        urls = self.regex.findall(content)
        return urls[0] if len(urls) == 1 else None

    def match(self, input: JobInput) -> bool:
        """Return True if ``input.content`` contains exactly one URL."""
        url = self._find_url(input.content)
        if url is None:
            return False
        self.url = url
        return True

    def process(self, input: JobInput) -> str:
        """Get content of website and return it as string.

        The URL is re-extracted from ``input``. Instances are shared via the
        registry, so the URL stored by an earlier ``match`` may belong to a
        different job.

        Raises:
            ValueError: if ``input.content`` does not contain exactly one URL.
        """
        url = self._find_url(input.content)
        if url is None:
            raise ValueError("Expected exactly one URL in the job content")
        text = self.client.get(url).text
        return self.template.format(url=url, content=text)
# class ProcessorRegistry:
# def __init__(self) -> None:
# self.registry: list[Processor] = []
# self.default_registry: list[Processor] = []
# self.set_default_processors()
# def set_default_processors(self) -> None:
# self.default_registry.extend([PlainUrlProcessor(), RawProcessor()])
# def register(self, processor: Processor) -> None:
# self.registry.append(processor)
# def dispatch(self, input: JobInput) -> Processor:
# for processor in self.registry + self.default_registry:
# if processor.match(input):
# return processor
# # should never be required, but eh
# return RawProcessor()