"""Registry of upscaler backends plus helpers to fetch, prepare and run them."""

from typing import List

from PIL import Image

from .common import CommonUpscaler, OfflineUpscaler
from .waifu2x import Waifu2xUpscaler
from .esrgan import ESRGANUpscaler
from .esrgan_pytorch import ESRGANUpscalerPytorch

# Maps the public upscaler key to the class that implements it.
UPSCALERS = {
    'waifu2x': Waifu2xUpscaler,
    'esrgan': ESRGANUpscaler,
    '4xultrasharp': ESRGANUpscalerPytorch,
}

# One instance per key, created on first request by get_upscaler().
upscaler_cache = {}


def get_upscaler(key: str, *args, **kwargs) -> CommonUpscaler:
    """Return the upscaler registered under *key*, creating it on first use.

    The instance is stored in ``upscaler_cache`` and reused afterwards, so
    ``*args`` / ``**kwargs`` only take effect on the call that actually
    constructs the instance; later calls for the same key ignore them.

    Args:
        key: Name of the upscaler; must be a key of ``UPSCALERS``.
        *args: Positional arguments forwarded to the upscaler constructor.
        **kwargs: Keyword arguments forwarded to the upscaler constructor.

    Returns:
        The cached upscaler instance for *key*.

    Raises:
        ValueError: If *key* is not a key of ``UPSCALERS``.
    """
    if key not in UPSCALERS:
        # Build the whole message with a single f-string. Applying ``%`` to
        # an f-string that already contains the caller-supplied key would
        # treat any ``%`` inside the key as a format directive and raise an
        # unrelated error instead of this one.
        choices = ','.join(UPSCALERS)
        raise ValueError(
            f'Could not find upscaler for: "{key}". '
            f'Choose from the following: {choices}'
        )

    # Compare against None rather than relying on truthiness, so an instance
    # that happens to be falsy is not rebuilt on every call.
    if upscaler_cache.get(key) is None:
        upscaler_cache[key] = UPSCALERS[key](*args, **kwargs)
    return upscaler_cache[key]


async def prepare(upscaler_key: str) -> None:
    """Get the upscaler for *upscaler_key* and await its download if offline.

    Only instances of ``OfflineUpscaler`` have ``download()`` awaited; for
    any other upscaler this just makes sure the instance exists in the cache.

    Raises:
        ValueError: If *upscaler_key* is not a registered upscaler.
    """
    upscaler = get_upscaler(upscaler_key)
    if isinstance(upscaler, OfflineUpscaler):
        await upscaler.download()


async def dispatch(
    upscaler_key: str,
    image_batch: List[Image.Image],
    upscale_ratio: int,
    device: str = 'cpu',
) -> List[Image.Image]:
    """Upscale *image_batch* with the upscaler registered as *upscaler_key*.

    Args:
        upscaler_key: Name of the upscaler; must be a key of ``UPSCALERS``.
        image_batch: Images to upscale.
        upscale_ratio: Requested ratio, passed through to the upscaler.
        device: Passed to ``load()`` of offline upscalers; defaults to
            ``'cpu'``.

    Returns:
        The result of the upscaler's ``upscale()`` call, or *image_batch*
        itself (the same list object) when *upscale_ratio* is 1.

    Raises:
        ValueError: If *upscaler_key* is not a registered upscaler and
            *upscale_ratio* is not 1.
    """
    # A ratio of 1 is a no-op: return before looking up any upscaler, which
    # also means the key is not validated on this path.
    if upscale_ratio == 1:
        return image_batch

    upscaler = get_upscaler(upscaler_key)
    if isinstance(upscaler, OfflineUpscaler):
        await upscaler.load(device)
    return await upscaler.upscale(image_batch, upscale_ratio)