import logging
import os
from typing import Union

import numpy as np
import spaces
import torch
from huggingface_hub import snapshot_download
from PIL import Image

from asset3d_gen.data.utils import get_images_from_grid

logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
)
logger = logging.getLogger(__name__)


__all__ = [
    "ImageStableSR",
    "ImageRealESRGAN",
]


class ImageStableSR:
    """x4 image super-resolution using the Stable Diffusion upscaler pipeline.

    Loads the pipeline eagerly in ``__init__`` (fp16, moved to ``device``).
    """

    def __init__(
        self,
        model_path: str = "stabilityai/stable-diffusion-x4-upscaler",
        device: str = "cuda",
    ) -> None:
        # Local import keeps diffusers an optional dependency of this module.
        from diffusers import StableDiffusionUpscalePipeline

        self.up_pipeline_x4 = StableDiffusionUpscalePipeline.from_pretrained(
            model_path,
            torch_dtype=torch.float16,
        ).to(device)
        self.up_pipeline_x4.set_progress_bar_config(disable=True)
        # self.up_pipeline_x4.enable_model_cpu_offload()

    @spaces.GPU
    def __call__(
        self,
        image: Union[Image.Image, np.ndarray],
        prompt: str = "",
        infer_step: int = 20,
    ) -> Image.Image:
        """Upscale ``image`` x4.

        Args:
            image: Input image (PIL image or numpy array); converted to RGB.
            prompt: Optional text prompt guiding the upscaler.
            infer_step: Number of diffusion inference steps.

        Returns:
            The upscaled image as a PIL ``Image``.
        """
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)

        image = image.convert("RGB")

        with torch.no_grad():
            upscaled_image = self.up_pipeline_x4(
                image=image,
                prompt=[prompt],
                num_inference_steps=infer_step,
            ).images[0]

        return upscaled_image


class ImageRealESRGAN:
    """x4 image super-resolution using RealESRGAN.

    The heavy model is built lazily on first call; weights are downloaded
    from the Hugging Face Hub when ``model_path`` is not supplied.
    """

    def __init__(self, outscale: int, model_path: str = None) -> None:
        # Monkey patch to support torchvision>=0.16: basicsr still imports
        # ``torchvision.transforms.functional_tensor``, which newer
        # torchvision no longer provides, so alias the one symbol it needs.
        import torchvision
        from packaging import version

        if version.parse(torchvision.__version__) > version.parse("0.16"):
            import sys
            import types

            import torchvision.transforms.functional as TF

            functional_tensor = types.ModuleType(
                "torchvision.transforms.functional_tensor"
            )
            functional_tensor.rgb_to_grayscale = TF.rgb_to_grayscale
            sys.modules["torchvision.transforms.functional_tensor"] = (
                functional_tensor
            )

        self.outscale = outscale
        self.model_path = model_path
        self.upsampler = None

    def _lazy_init(self) -> None:
        """Build the ``RealESRGANer`` on first use (idempotent)."""
        if self.upsampler is not None:
            return

        from basicsr.archs.rrdbnet_arch import RRDBNet
        from realesrgan import RealESRGANer

        model = RRDBNet(
            num_in_ch=3,
            num_out_ch=3,
            num_feat=64,
            num_block=23,
            num_grow_ch=32,
            scale=4,
        )

        model_path = self.model_path
        if model_path is None:
            # Fetch only the super-resolution weights from the repo.
            suffix = "super_resolution"
            model_path = snapshot_download(
                repo_id="xinjjj/RoboAssetGen", allow_patterns=f"{suffix}/*"
            )
            model_path = os.path.join(
                model_path, suffix, "RealESRGAN_x4plus.pth"
            )

        self.upsampler = RealESRGANer(
            scale=4,
            model_path=model_path,
            model=model,
            pre_pad=0,
            half=True,
        )

    @spaces.GPU
    def __call__(self, image: Union[Image.Image, np.ndarray]) -> Image.Image:
        """Upscale ``image`` by ``self.outscale`` and return a PIL image."""
        self._lazy_init()

        if isinstance(image, Image.Image):
            image = np.array(image)

        with torch.no_grad():
            output, _ = self.upsampler.enhance(image, outscale=self.outscale)

        return Image.fromarray(output)


if __name__ == "__main__":
    color_path = "outputs/texture_mesh_gen/multi_view/color_sample0.png"

    # Use RealESRGAN_x4plus for x4 (512->2048) image super resolution.
    # model_path = "/horizon-bucket/robot_lab/users/xinjie.wang/weights/super_resolution/RealESRGAN_x4plus.pth"  # noqa
    super_model = ImageRealESRGAN(outscale=4)
    multiviews = get_images_from_grid(color_path, img_size=512)
    multiviews = [super_model(img.convert("RGB")) for img in multiviews]
    for idx, img in enumerate(multiviews):
        img.save(f"sr{idx}.png")

    # # Use stable diffusion for x4 (512->2048) image super resolution.
    # super_model = ImageStableSR()
    # multiviews = get_images_from_grid(color_path, img_size=512)
    # multiviews = [super_model(img) for img in multiviews]
    # for idx, img in enumerate(multiviews):
    #     img.save(f"sr_stable{idx}.png")