Spaces:
Sleeping
Sleeping
import asyncio | |
import hashlib | |
import inspect | |
import math | |
import random | |
import time | |
from dataclasses import dataclass | |
from enum import Enum | |
from functools import partial, wraps | |
from io import BytesIO | |
from typing import ( | |
TYPE_CHECKING, | |
Any, | |
Callable, | |
Coroutine, | |
List, | |
Literal, | |
Optional, | |
Protocol, | |
Tuple, | |
TypeVar, | |
) | |
import httpx | |
from PIL.Image import Image as IMG | |
from pil_utils import BuildImage, Text2Image | |
from pil_utils.types import ColorType, FontStyle, FontWeight | |
from typing_extensions import ParamSpec | |
from .config import meme_config | |
from .exception import MemeGeneratorException | |
if TYPE_CHECKING: | |
from .meme import Meme | |
P = ParamSpec("P") | |
R = TypeVar("R") | |
def run_sync(call: Callable[P, R]) -> Callable[P, Coroutine[None, None, R]]: | |
"""一个用于包装 sync function 为 async function 的装饰器 | |
参数: | |
call: 被装饰的同步函数 | |
""" | |
async def _wrapper(*args: P.args, **kwargs: P.kwargs) -> R: | |
loop = asyncio.get_running_loop() | |
pfunc = partial(call, *args, **kwargs) | |
result = await loop.run_in_executor(None, pfunc) | |
return result | |
return _wrapper | |
def is_coroutine_callable(call: Callable[..., Any]) -> bool: | |
"""检查 call 是否是一个 callable 协程函数""" | |
if inspect.isroutine(call): | |
return inspect.iscoroutinefunction(call) | |
if inspect.isclass(call): | |
return False | |
func_ = getattr(call, "__call__", None) | |
return inspect.iscoroutinefunction(func_) | |
def save_gif(frames: List[IMG], duration: float) -> BytesIO: | |
output = BytesIO() | |
frames[0].save( | |
output, | |
format="GIF", | |
save_all=True, | |
append_images=frames[1:], | |
duration=duration * 1000, | |
loop=0, | |
disposal=2, | |
optimize=False, | |
) | |
# 没有超出最大大小,直接返回 | |
nbytes = output.getbuffer().nbytes | |
if nbytes <= meme_config.gif.gif_max_size * 10**6: | |
return output | |
# 超出最大大小,帧数超出最大帧数时,缩减帧数 | |
n_frames = len(frames) | |
gif_max_frames = meme_config.gif.gif_max_frames | |
if n_frames > gif_max_frames: | |
index = range(n_frames) | |
ratio = n_frames / gif_max_frames | |
index = (int(i * ratio) for i in range(gif_max_frames)) | |
new_duration = duration * ratio | |
new_frames = [frames[i] for i in index] | |
return save_gif(new_frames, new_duration) | |
# 超出最大大小,帧数没有超出最大帧数时,缩小尺寸 | |
new_frames = [ | |
frame.resize((int(frame.width * 0.9), int(frame.height * 0.9))) | |
for frame in frames | |
] | |
return save_gif(new_frames, duration) | |
class Maker(Protocol): | |
def __call__(self, img: BuildImage) -> BuildImage: | |
... | |
class GifMaker(Protocol): | |
def __call__(self, i: int) -> Maker: | |
... | |
def get_avg_duration(image: IMG) -> float: | |
if not getattr(image, "is_animated", False): | |
return 0 | |
total_duration = 0 | |
for i in range(image.n_frames): | |
image.seek(i) | |
total_duration += image.info["duration"] | |
return total_duration / image.n_frames | |
def split_gif(image: IMG) -> List[IMG]: | |
frames: List[IMG] = [] | |
update_mode = "full" | |
for i in range(image.n_frames): | |
image.seek(i) | |
if image.tile: # type: ignore | |
update_region = image.tile[0][1][2:] # type: ignore | |
if update_region != image.size: | |
update_mode = "partial" | |
break | |
last_frame: Optional[IMG] = None | |
for i in range(image.n_frames): | |
image.seek(i) | |
frame = image.copy() | |
if update_mode == "partial" and last_frame: | |
frame = last_frame.copy().paste(frame) | |
frames.append(frame) | |
image.seek(0) | |
if image.info.__contains__("transparency"): | |
frames[0].info["transparency"] = image.info["transparency"] | |
return frames | |
def make_jpg_or_gif( | |
img: BuildImage, func: Maker, keep_transparency: bool = False | |
) -> BytesIO: | |
""" | |
制作静图或者动图 | |
:params | |
* ``img``: 输入图片 | |
* ``func``: 图片处理函数,输入img,返回处理后的图片 | |
* ``keep_transparency``: 传入gif时,是否保留该gif的透明度 | |
""" | |
image = img.image | |
if not getattr(image, "is_animated", False): | |
return func(img).save_jpg() | |
else: | |
frames = split_gif(image) | |
duration = get_avg_duration(image) / 1000 | |
frames = [func(BuildImage(frame)).image for frame in frames] | |
if keep_transparency: | |
image.seek(0) | |
if image.info.__contains__("transparency"): | |
frames[0].info["transparency"] = image.info["transparency"] | |
return save_gif(frames, duration) | |
def make_png_or_gif( | |
img: BuildImage, func: Maker, keep_transparency: bool = False | |
) -> BytesIO: | |
""" | |
制作静图或者动图 | |
:params | |
* ``img``: 输入图片 | |
* ``func``: 图片处理函数,输入img,返回处理后的图片 | |
* ``keep_transparency``: 传入gif时,是否保留该gif的透明度 | |
""" | |
image = img.image | |
if not getattr(image, "is_animated", False): | |
return func(img).save_png() | |
else: | |
frames = split_gif(image) | |
duration = get_avg_duration(image) / 1000 | |
frames = [func(BuildImage(frame)).image for frame in frames] | |
if keep_transparency: | |
image.seek(0) | |
if image.info.__contains__("transparency"): | |
frames[0].info["transparency"] = image.info["transparency"] | |
return save_gif(frames, duration) | |
class FrameAlignPolicy(Enum): | |
""" | |
要叠加的gif长度大于基准gif时,是否延长基准gif长度以对齐两个gif | |
""" | |
no_extend = 0 | |
"""不延长""" | |
extend_first = 1 | |
"""延长第一帧""" | |
extend_last = 2 | |
"""延长最后一帧""" | |
extend_loop = 3 | |
"""以循环方式延长""" | |
def make_gif_or_combined_gif( | |
img: BuildImage, | |
maker: GifMaker, | |
frame_num: int, | |
duration: float, | |
frame_align: FrameAlignPolicy = FrameAlignPolicy.no_extend, | |
input_based: bool = False, | |
keep_transparency: bool = False, | |
) -> BytesIO: | |
""" | |
使用静图或动图制作gif | |
:params | |
* ``img``: 输入图片,如头像 | |
* ``maker``: 图片处理函数生成,传入第几帧,返回对应的图片处理函数 | |
* ``frame_num``: 目标gif的帧数 | |
* ``duration``: 相邻帧之间的时间间隔,单位为秒 | |
* ``frame_align``: 要叠加的gif长度大于基准gif时,gif长度对齐方式 | |
* ``input_based``: 是否以输入gif为基准合成gif,默认为`False`,即以目标gif为基准 | |
* ``keep_transparency``: 传入gif时,是否保留该gif的透明度 | |
""" | |
image = img.image | |
if not getattr(image, "is_animated", False): | |
return save_gif([maker(i)(img).image for i in range(frame_num)], duration) | |
frame_num_in = image.n_frames | |
duration_in = get_avg_duration(image) / 1000 | |
total_duration_in = frame_num_in * duration_in | |
total_duration = frame_num * duration | |
if input_based: | |
frame_num_base = frame_num_in | |
frame_num_fit = frame_num | |
duration_base = duration_in | |
duration_fit = duration | |
total_duration_base = total_duration_in | |
total_duration_fit = total_duration | |
else: | |
frame_num_base = frame_num | |
frame_num_fit = frame_num_in | |
duration_base = duration | |
duration_fit = duration_in | |
total_duration_base = total_duration | |
total_duration_fit = total_duration_in | |
frame_idxs: List[int] = list(range(frame_num_base)) | |
diff_duration = total_duration_fit - total_duration_base | |
diff_num = int(diff_duration / duration_base) | |
if diff_duration >= duration_base: | |
if frame_align == FrameAlignPolicy.extend_first: | |
frame_idxs = [0] * diff_num + frame_idxs | |
elif frame_align == FrameAlignPolicy.extend_last: | |
frame_idxs += [frame_num_base - 1] * diff_num | |
elif frame_align == FrameAlignPolicy.extend_loop: | |
frame_num_total = frame_num_base | |
# 重复基准gif,直到两个gif总时长之差在1个间隔以内,或总帧数超出最大帧数 | |
while frame_num_total + frame_num_base <= meme_config.gif.gif_max_frames: | |
frame_num_total += frame_num_base | |
frame_idxs += list(range(frame_num_base)) | |
multiple = round(frame_num_total * duration_base / total_duration_fit) | |
if ( | |
math.fabs( | |
total_duration_fit * multiple - frame_num_total * duration_base | |
) | |
<= duration_base | |
): | |
break | |
frames: List[IMG] = [] | |
frame_idx_fit = 0 | |
time_start = 0 | |
for i, idx in enumerate(frame_idxs): | |
while frame_idx_fit < frame_num_fit: | |
if ( | |
frame_idx_fit * duration_fit | |
<= i * duration_base - time_start | |
< (frame_idx_fit + 1) * duration_fit | |
): | |
if input_based: | |
idx_in = idx | |
idx_maker = frame_idx_fit | |
else: | |
idx_in = frame_idx_fit | |
idx_maker = idx | |
func = maker(idx_maker) | |
image.seek(idx_in) | |
frames.append(func(BuildImage(image.copy())).image) | |
break | |
else: | |
frame_idx_fit += 1 | |
if frame_idx_fit >= frame_num_fit: | |
frame_idx_fit = 0 | |
time_start += total_duration_fit | |
if keep_transparency: | |
image.seek(0) | |
if image.info.__contains__("transparency"): | |
frames[0].info["transparency"] = image.info["transparency"] | |
return save_gif(frames, duration) | |
async def translate(text: str, lang_from: str = "auto", lang_to: str = "zh") -> str: | |
appid = meme_config.translate.baidu_trans_appid | |
apikey = meme_config.translate.baidu_trans_apikey | |
if not appid or not apikey: | |
raise MemeGeneratorException( | |
"The `baidu_trans_appid` or `baidu_trans_apikey` is not set." | |
"Please check your config file!" | |
) | |
salt = str(round(time.time() * 1000)) | |
sign_raw = appid + text + salt + apikey | |
sign = hashlib.md5(sign_raw.encode("utf8")).hexdigest() | |
params = { | |
"q": text, | |
"from": lang_from, | |
"to": lang_to, | |
"appid": appid, | |
"salt": salt, | |
"sign": sign, | |
} | |
url = "https://fanyi-api.baidu.com/api/trans/vip/translate" | |
async with httpx.AsyncClient() as client: | |
resp = await client.get(url, params=params) | |
result = resp.json() | |
return result["trans_result"][0]["dst"] | |
def random_text() -> str: | |
return random.choice(["刘一", "陈二", "张三", "李四", "王五", "赵六", "孙七", "周八", "吴九", "郑十"]) | |
def random_image() -> BytesIO: | |
text = random.choice(["😂", "😅", "🤗", "🤤", "🥵", "🥰", "😍", "😭", "😋", "😏"]) | |
return ( | |
BuildImage.new("RGBA", (500, 500), "white") | |
.draw_text((0, 0, 500, 500), text, max_fontsize=400) | |
.save_png() | |
) | |
class TextProperties: | |
fill: ColorType = "black" | |
style: FontStyle = "normal" | |
weight: FontWeight = "normal" | |
stroke_width: int = 0 | |
stroke_fill: Optional[ColorType] = None | |
def default_template(meme: "Meme", number: int) -> str: | |
return f"{number}. {'/'.join(meme.keywords)}" | |
def render_meme_list( | |
meme_list: List[Tuple["Meme", TextProperties]], | |
*, | |
template: Callable[["Meme", int], str] = default_template, | |
order_direction: Literal["row", "column"] = "column", | |
columns: int = 4, | |
column_align: Literal["left", "center", "right"] = "left", | |
item_padding: Tuple[int, int] = (15, 6), | |
image_padding: Tuple[int, int] = (50, 50), | |
bg_color: ColorType = "white", | |
fontsize: int = 30, | |
fontname: str = "", | |
fallback_fonts: List[str] = [], | |
) -> BytesIO: | |
item_images: List[Text2Image] = [] | |
for i, (meme, properties) in enumerate(meme_list, start=1): | |
text = template(meme, i) | |
t2m = Text2Image.from_text( | |
text, | |
fontsize=fontsize, | |
style=properties.style, | |
weight=properties.weight, | |
fill=properties.fill, | |
stroke_width=properties.stroke_width, | |
stroke_fill=properties.stroke_fill, | |
fontname=fontname, | |
fallback_fonts=fallback_fonts, | |
) | |
item_images.append(t2m) | |
char_A = ( | |
Text2Image.from_text( | |
"A", fontsize=fontsize, fontname=fontname, fallback_fonts=fallback_fonts | |
) | |
.lines[0] | |
.chars[0] | |
) | |
num_per_col = math.ceil(len(item_images) / columns) | |
column_images: List[BuildImage] = [] | |
for col in range(columns): | |
if order_direction == "column": | |
images = item_images[col * num_per_col : (col + 1) * num_per_col] | |
else: | |
images = [ | |
item_images[num * columns + col] | |
for num in range((len(item_images) - col - 1) // columns + 1) | |
] | |
img_w = max((t2m.width for t2m in images)) + item_padding[0] * 2 | |
img_h = (char_A.ascent + item_padding[1] * 2) * len(images) + char_A.descent | |
image = BuildImage.new("RGB", (img_w, img_h), bg_color) | |
y = item_padding[1] | |
for t2m in images: | |
if column_align == "left": | |
x = 0 | |
elif column_align == "center": | |
x = (img_w - t2m.width - item_padding[0] * 2) // 2 | |
else: | |
x = img_w - t2m.width - item_padding[0] * 2 | |
t2m.draw_on_image(image.image, (x, y)) | |
y += char_A.ascent + item_padding[1] * 2 | |
column_images.append(image) | |
img_w = sum((img.width for img in column_images)) + image_padding[0] * 2 | |
img_h = max((img.height for img in column_images)) + image_padding[1] * 2 | |
image = BuildImage.new("RGB", (img_w, img_h), bg_color) | |
x, y = image_padding | |
for img in column_images: | |
image.paste(img, (x, y)) | |
x += img.width | |
return image.save_jpg() | |