CikeyQi's picture
Upload 964 files (#1)
346533a
raw
history blame
5.62 kB
import copy
from argparse import ArgumentError, ArgumentParser
from contextvars import ContextVar
from dataclasses import dataclass, field
from io import BytesIO
from pathlib import Path
from typing import (
IO,
Any,
Awaitable,
Callable,
Dict,
List,
Literal,
Optional,
Type,
TypeVar,
Union,
cast,
)
from pil_utils import BuildImage
from pydantic import BaseModel, ValidationError
from .exception import (
ArgModelMismatch,
ArgParserExit,
ImageNumberMismatch,
OpenImageFailed,
ParserExit,
TextNumberMismatch,
TextOrNameNotEnough,
)
from .utils import is_coroutine_callable, random_image, random_text, run_sync
class UserInfo(BaseModel):
    """Per-user metadata; used as the element type of `MemeArgsModel.user_infos`."""

    # User name; defaults to an empty string.
    name: str = ""
    # Restricted to these three literals by validation; defaults to "unknown".
    gender: Literal["male", "female", "unknown"] = "unknown"
class MemeArgsModel(BaseModel):
    """Base pydantic model for a meme's arguments.

    `Meme.__call__` validates its `args` dict against this model, or against
    the subclass configured in `MemeArgsType.model` when one is set.
    """

    # Per-user metadata list; empty by default.
    # NOTE(review): relies on pydantic copying mutable field defaults per
    # instance, so this `[]` is not shared between instances — confirm for the
    # pydantic version in use.
    user_infos: List[UserInfo] = []
# Type variable bound to MemeArgsModel, so a meme function may annotate its
# `args` parameter with its own MemeArgsModel subclass.
ArgsModel = TypeVar("ArgsModel", bound=MemeArgsModel)

# A meme implementation receives the opened images, the texts and the
# validated argument model, and yields a BytesIO either directly (sync)...
_SyncMemeFunction = Callable[[List[BuildImage], List[str], ArgsModel], BytesIO]
# ...or through an awaitable (async).
_AsyncMemeFunction = Callable[
    [List[BuildImage], List[str], ArgsModel], Awaitable[BytesIO]
]
MemeFunction = Union[_SyncMemeFunction, _AsyncMemeFunction]

# Capture buffer for text the argument parser would otherwise print.
# While it holds a string, MemeArgsParser appends its output to it instead of
# writing to a stream; Meme.parse_args sets it to "" around each parse.
parser_message: ContextVar[str] = ContextVar("parser_message")
class MemeArgsParser(ArgumentParser):
    """Argument parser for `shell_like` commands that never exits the program.

    Usage:
        Same as `argparse.ArgumentParser`; see the
        [argparse](https://docs.python.org/3/library/argparse.html) docs.

    While the `parser_message` context variable holds a string, anything the
    parser prints is appended to it instead of being written to a stream, and
    `exit()` raises `ParserExit` carrying that captured text rather than
    terminating the interpreter.
    """

    def _print_message(self, message: str, file: Optional[IO[str]] = None):
        buffered = parser_message.get(None)
        if buffered is None:
            # No capture buffer active: use argparse's normal output path.
            super()._print_message(message, file)
            return
        parser_message.set(buffered + message)

    def exit(self, status: int = 0, message: Optional[str] = None):
        # Overrides argparse's exit(): record the message, then raise so the
        # caller can handle the failure instead of the process exiting.
        if message:
            self._print_message(message)
        captured = parser_message.get(None)
        raise ParserExit(status=status, error_message=captured)
@dataclass
class MemeArgsType:
    """Custom-argument configuration attached to a meme's params."""

    # Parser for command-line style arguments; Meme.parse_args deep-copies it
    # before adding the positional `texts` argument.
    parser: MemeArgsParser
    # Pydantic model that Meme.__call__ validates the `args` dict against.
    model: Type[MemeArgsModel]
    # List of model instances, empty by default; not read in the visible code.
    # NOTE(review): presumably sample argument sets — TODO confirm.
    instances: List[MemeArgsModel] = field(default_factory=list)
@dataclass
class MemeParamsType:
    """Input constraints and defaults for a meme."""

    # Inclusive bounds on the number of images, enforced by Meme.__call__.
    min_images: int = 0
    max_images: int = 0
    # Inclusive bounds on the number of texts, enforced by Meme.__call__.
    min_texts: int = 0
    max_texts: int = 0
    # Texts Meme.generate_preview uses when their count lies within
    # [min_texts, max_texts]; otherwise it generates random texts.
    default_texts: List[str] = field(default_factory=list)
    # Custom argument parser/model; None means the plain MemeArgsModel and a
    # fresh MemeArgsParser are used.
    args_type: Optional[MemeArgsType] = None
@dataclass
class Meme:
    """A meme generator together with its input constraints.

    Fields:
        key: identifier of the meme; passed to every error this class raises.
        function: the generator, sync or async, called with the keyword
            arguments ``images``, ``texts`` and ``args``.
        params_type: image/text count limits, default texts and the optional
            custom argument parser/model.
        keywords: list of strings, empty by default.
        patterns: list of strings, empty by default.
    """

    key: str
    function: MemeFunction
    params_type: MemeParamsType
    keywords: List[str] = field(default_factory=list)
    patterns: List[str] = field(default_factory=list)

    async def __call__(
        self,
        *,
        images: Optional[
            Union[List[str], List[Path], List[bytes], List[BytesIO]]
        ] = None,
        texts: Optional[List[str]] = None,
        args: Optional[Dict[str, Any]] = None,
    ) -> BytesIO:
        """Validate the inputs and run the meme function.

        Args:
            images: image sources. ``bytes`` items are wrapped in ``BytesIO``
                and every item is opened with ``BuildImage.open``.
                ``None`` means no images.
            texts: texts passed to the meme function. ``None`` means no texts.
            args: raw argument mapping, validated against the configured
                argument model (``MemeArgsModel`` when none is configured).
                ``None`` means an empty mapping.

        Returns:
            The result of the meme function.

        Raises:
            ImageNumberMismatch: image count outside [min_images, max_images].
            TextNumberMismatch: text count outside [min_texts, max_texts].
            ArgModelMismatch: ``args`` fails model validation.
            OpenImageFailed: an image could not be opened.
        """
        # None defaults give each call fresh containers; shared `[]`/`{}`
        # defaults would carry any mutation made by a meme function (e.g.
        # appending to `texts`) into later calls.
        images = [] if images is None else images
        texts = [] if texts is None else texts
        args = {} if args is None else args

        params = self.params_type
        if not (params.min_images <= len(images) <= params.max_images):
            raise ImageNumberMismatch(
                self.key, params.min_images, params.max_images
            )
        if not (params.min_texts <= len(texts) <= params.max_texts):
            raise TextNumberMismatch(self.key, params.min_texts, params.max_texts)

        if args_type := params.args_type:
            args_model = args_type.model
        else:
            args_model = MemeArgsModel
        try:
            model = args_model.parse_obj(args)
        except ValidationError as e:
            raise ArgModelMismatch(self.key, str(e)) from e

        imgs: List[BuildImage] = []
        try:
            for image in images:
                if isinstance(image, bytes):
                    image = BytesIO(image)
                imgs.append(BuildImage.open(image))
        except Exception as e:
            # Any failure while decoding/opening is reported as one domain error.
            raise OpenImageFailed(str(e)) from e

        values = {"images": imgs, "texts": texts, "args": model}
        if is_coroutine_callable(self.function):
            return await cast(Callable[..., Awaitable[BytesIO]], self.function)(
                **values
            )
        # Sync meme functions are executed through run_sync and awaited.
        return await run_sync(cast(Callable[..., BytesIO], self.function))(**values)

    def parse_args(self, args: Optional[List[str]] = None) -> Dict[str, Any]:
        """Parse command-line style arguments into a plain dict.

        A positional ``texts`` argument (zero or more values) is added to a
        copy of the configured parser before parsing. ``None`` means "no
        arguments"; unlike ``ArgumentParser.parse_args`` it never falls back to
        ``sys.argv``.

        Raises:
            ArgParserExit: the parser reported an error or requested an exit
                (e.g. help); carries the captured parser output.
        """
        # argparse reads sys.argv when given None, which must not happen here.
        args = [] if args is None else args
        # Copy the configured parser so adding `texts` does not mutate the
        # shared instance stored on params_type.
        parser = (
            copy.deepcopy(self.params_type.args_type.parser)
            if self.params_type.args_type
            else MemeArgsParser()
        )
        parser.add_argument("texts", nargs="*", default=[])
        # Capture parser output for this parse only; restored in `finally`.
        token = parser_message.set("")
        try:
            return vars(parser.parse_args(args))
        except ArgumentError as e:
            raise ArgParserExit(self.key, str(e)) from e
        except ParserExit as e:
            raise ArgParserExit(self.key, e.error_message) from e
        finally:
            parser_message.reset(token)

    async def generate_preview(
        self, *, args: Optional[Dict[str, Any]] = None
    ) -> BytesIO:
        """Render the meme with placeholder inputs.

        Uses ``min_images`` random images. If the configured default texts
        number within [min_texts, max_texts], a copy of them is used;
        otherwise ``min_texts`` random texts are used. Each time the meme
        function raises ``TextOrNameNotEnough``, one more random text is
        appended and the call is retried. The retries are bounded: once the
        text count exceeds ``max_texts``, ``__call__`` raises
        ``TextNumberMismatch`` before invoking the function.
        """
        params = self.params_type
        images = [random_image() for _ in range(params.min_images)]
        if params.min_texts <= len(params.default_texts) <= params.max_texts:
            texts = params.default_texts.copy()
        else:
            texts = [random_text() for _ in range(params.min_texts)]

        while True:
            try:
                return await self.__call__(images=images, texts=texts, args=args)
            except TextOrNameNotEnough:
                texts.append(random_text())