Pierre Chapuis
simplify enhancer code
badbac0 unverified
from dataclasses import dataclass
from pathlib import Path
from typing import Any
import torch
from PIL import Image
from refiners.foundationals.latent_diffusion.stable_diffusion_1.multi_upscaler import (
MultiUpscaler,
UpscalerCheckpoints,
)
from esrgan_model import UpscalerESRGAN
@dataclass(kw_only=True)
class ESRGANUpscalerCheckpoints(UpscalerCheckpoints):
    """Checkpoint locations used by ``ESRGANUpscaler``.

    Extends the base upscaler checkpoints with one extra entry: the ESRGAN
    checkpoint. All fields are keyword-only.
    """

    # ESRGAN checkpoint path; handed to UpscalerESRGAN by ESRGANUpscaler.__init__.
    esrgan: Path
class ESRGANUpscaler(MultiUpscaler):
    """``MultiUpscaler`` whose pre-upscale step runs an ESRGAN model first.

    ``pre_upscale`` upscales the input with ESRGAN (tiled), then delegates the
    remaining resize to ``MultiUpscaler.pre_upscale``.
    """

    # Scale factor assumed to be applied by the ESRGAN pass; the remaining
    # factor is computed by dividing by it.
    # NOTE(review): hard-coded to 4 in the original code; confirm it matches the
    # checkpoint loaded by UpscalerESRGAN. Subclasses may override it.
    ESRGAN_SCALE_FACTOR: int = 4

    def __init__(
        self,
        checkpoints: ESRGANUpscalerCheckpoints,
        device: torch.device,
        dtype: torch.dtype,
    ) -> None:
        """Load the base ``MultiUpscaler`` models plus the ESRGAN model.

        Args:
            checkpoints: Checkpoint paths, including ``checkpoints.esrgan``.
            device: Device to place the models on.
            dtype: Dtype to use for the models.
        """
        super().__init__(checkpoints=checkpoints, device=device, dtype=dtype)
        self.esrgan = UpscalerESRGAN(checkpoints.esrgan, device=self.device, dtype=self.dtype)
        # Explicit move, kept from the original; possibly redundant if the
        # UpscalerESRGAN constructor already places its weights — TODO confirm.
        self.esrgan.to(device=device, dtype=dtype)

    def to(self, device: torch.device, dtype: torch.dtype) -> "ESRGANUpscaler":
        """Move the ESRGAN and Stable Diffusion models to ``device``/``dtype``.

        Also updates ``self.device`` and ``self.dtype``.

        Returns:
            ``self``. This allows the ``obj = obj.to(...)`` pattern used for
            ``self.sd`` below, matching ``torch.nn.Module.to``.
            Previously this returned ``None``.
        """
        self.esrgan.to(device=device, dtype=dtype)
        self.sd = self.sd.to(device=device, dtype=dtype)
        self.device = device
        self.dtype = dtype
        return self

    def pre_upscale(self, image: Image.Image, upscale_factor: float, **_: Any) -> Image.Image:
        """Upscale ``image`` with ESRGAN, then resize by the remaining factor.

        Args:
            image: Input image.
            upscale_factor: Presumably the overall scale relative to the input
                image, since the ESRGAN factor is divided out below.
            **_: Extra keyword arguments; accepted and ignored (not forwarded).

        Returns:
            The image returned by ``MultiUpscaler.pre_upscale`` after the
            ESRGAN pass.
        """
        image = self.esrgan.upscale_with_tiling(image)
        # The ESRGAN pass already applied its own scale, so only the remainder
        # is left for the base-class resize.
        remaining_factor = upscale_factor / self.ESRGAN_SCALE_FACTOR
        return super().pre_upscale(image=image, upscale_factor=remaining_factor)