Spaces:
Running
on
Zero
Running
on
Zero
File size: 4,463 Bytes
55ed985 146eff7 55ed985 4811e40 55ed985 146eff7 55ed985 f8d7009 2a08301 10c708b 2a08301 10c708b f8d7009 2a08301 f8d7009 55ed985 f8d7009 55ed985 f8d7009 55ed985 146eff7 55ed985 f8d7009 55ed985 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 |
import logging
import os
from typing import Union
import spaces
import numpy as np
import torch
from huggingface_hub import snapshot_download
from PIL import Image
from asset3d_gen.data.utils import get_images_from_grid
# Configure root logging at import time: timestamp, level name, message.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)

# Names exported by ``from <module> import *``.
__all__ = ["ImageStableSR", "ImageRealESRGAN"]
class ImageStableSR:
    """Super-resolve images with a Stable Diffusion upscaling pipeline.

    The pipeline is loaded in float16 and moved to ``device`` as soon as
    the object is constructed.
    """

    def __init__(
        self,
        model_path: str = "stabilityai/stable-diffusion-x4-upscaler",
        device="cuda",
    ) -> None:
        """Load the upscaling pipeline.

        Args:
            model_path: Model identifier or path handed to
                ``StableDiffusionUpscalePipeline.from_pretrained``.
            device: Device the loaded pipeline is moved to.
        """
        # Imported locally so diffusers is only required when this class
        # is actually instantiated.
        from diffusers import StableDiffusionUpscalePipeline

        pipeline = StableDiffusionUpscalePipeline.from_pretrained(
            model_path, torch_dtype=torch.float16
        )
        self.up_pipeline_x4 = pipeline.to(device)
        self.up_pipeline_x4.set_progress_bar_config(disable=True)
        # self.up_pipeline_x4.enable_model_cpu_offload()

    @spaces.GPU
    def __call__(
        self,
        image: Union[Image.Image, np.ndarray],
        prompt: str = "",
        infer_step: int = 20,
    ) -> Image.Image:
        """Upscale a single image.

        Args:
            image: Input as a PIL image or a numpy array; arrays are
                converted with ``Image.fromarray``.
            prompt: Text prompt forwarded to the pipeline.
            infer_step: Number of inference steps for the pipeline.

        Returns:
            The first image of the pipeline output.
        """
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        # The pipeline is always fed an RGB image, whatever the input mode.
        rgb_image = image.convert("RGB")

        with torch.no_grad():
            result = self.up_pipeline_x4(
                image=rgb_image,
                prompt=[prompt],
                num_inference_steps=infer_step,
            )

        return result.images[0]
class ImageRealESRGAN:
    """Upscale images with a lazily constructed Real-ESRGAN upsampler.

    The network and its weights are only created on the first call, so
    constructing the object itself does not load any model.
    """

    def __init__(
        self, outscale: int, model_path: Union[str, None] = None
    ) -> None:
        """Store the settings and make sure the torchvision shim exists.

        Args:
            outscale: Scale factor passed to ``RealESRGANer.enhance``.
            model_path: Path to the weights file. When ``None``, the
                weights are downloaded from the Hugging Face Hub on
                first use.
        """
        self._patch_torchvision()

        self.outscale = outscale
        self.model_path = model_path
        self.upsampler = None  # Built on demand by ``_lazy_init``.

    @staticmethod
    def _patch_torchvision() -> None:
        """Provide ``torchvision.transforms.functional_tensor`` if missing.

        Newer torchvision releases no longer ship this module, so a
        minimal stand-in exposing ``rgb_to_grayscale`` is registered.
        The stand-in is installed only when the real module cannot be
        imported, so an existing module is never replaced and no version
        comparison is needed.
        """
        import importlib

        module_name = "torchvision.transforms.functional_tensor"
        try:
            importlib.import_module(module_name)
        except ImportError:
            import sys
            import types

            import torchvision.transforms.functional as TF

            functional_tensor = types.ModuleType(module_name)
            functional_tensor.rgb_to_grayscale = TF.rgb_to_grayscale
            sys.modules[module_name] = functional_tensor

    def _lazy_init(self) -> None:
        """Create ``self.upsampler`` on first use; later calls do nothing."""
        if self.upsampler is not None:
            return

        # Heavy dependencies are imported only when the model is needed.
        from basicsr.archs.rrdbnet_arch import RRDBNet
        from realesrgan import RealESRGANer
        from huggingface_hub import snapshot_download

        model = RRDBNet(
            num_in_ch=3,
            num_out_ch=3,
            num_feat=64,
            num_block=23,
            num_grow_ch=32,
            scale=4,
        )

        model_path = self.model_path
        if model_path is None:
            # No explicit weights given: fetch only the super-resolution
            # folder of the hub repository and point at the x4 checkpoint.
            suffix = "super_resolution"
            repo_dir = snapshot_download(
                repo_id="xinjjj/RoboAssetGen", allow_patterns=f"{suffix}/*"
            )
            model_path = os.path.join(
                repo_dir, suffix, "RealESRGAN_x4plus.pth"
            )

        self.upsampler = RealESRGANer(
            scale=4,
            model_path=model_path,
            model=model,
            pre_pad=0,
            half=True,
        )

    @spaces.GPU
    def __call__(self, image: Union[Image.Image, np.ndarray]) -> Image.Image:
        """Upscale a single image.

        Args:
            image: Input as a PIL image or a numpy array; PIL images are
                converted with ``np.array``.

        Returns:
            The enhanced image as a PIL image.
        """
        self._lazy_init()

        if isinstance(image, Image.Image):
            image = np.array(image)

        # NOTE(review): RealESRGANer.enhance is usually fed cv2-style
        # (BGR) arrays -- confirm that passing RGB data is intended.
        with torch.no_grad():
            output, _ = self.upsampler.enhance(image, outscale=self.outscale)

        return Image.fromarray(output)
if __name__ == "__main__":
    # Demo: split a multi-view grid image into 512px views, upscale each
    # one, then write the results to the working directory.
    color_path = "outputs/texture_mesh_gen/multi_view/color_sample0.png"

    # Use RealESRGAN_x4plus for x4 (512->2048) image super resolution.
    # model_path = "/horizon-bucket/robot_lab/users/xinjie.wang/weights/super_resolution/RealESRGAN_x4plus.pth"  # noqa
    super_model = ImageRealESRGAN(outscale=4)

    # Phase 1: upscale every view before anything is written to disk.
    upscaled_views = []
    for view in get_images_from_grid(color_path, img_size=512):
        upscaled_views.append(super_model(view.convert("RGB")))

    # Phase 2: save the upscaled views as sr0.png, sr1.png, ...
    for idx, upscaled in enumerate(upscaled_views):
        upscaled.save(f"sr{idx}.png")

    # # Use stable diffusion for x4 (512->2048) image super resolution.
    # super_model = ImageStableSR()
    # multiviews = get_images_from_grid(color_path, img_size=512)
    # multiviews = [super_model(img) for img in multiviews]
    # for idx, img in enumerate(multiviews):
    #     img.save(f"sr_stable{idx}.png")
|