import os

import cv2
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from omegaconf import OmegaConf
from PIL import Image
from torch.nn.functional import interpolate

from sgm.util import instantiate_from_config


def get_state_dict(d):
    # Checkpoints may nest the weights under a 'state_dict' key; unwrap if present.
    return d.get('state_dict', d)


def load_state_dict(ckpt_path, location='cpu'):
    _, extension = os.path.splitext(ckpt_path)

    # Treat 'username/repo/filename' references that do not exist locally as
    # Hugging Face Hub downloads; plain local paths are used as-is.
    if '/' in ckpt_path and not os.path.exists(ckpt_path):
        parts = ckpt_path.split('/')
        if len(parts) == 3:
            repo_id = f"{parts[0]}/{parts[1]}"
            filename = parts[2]
            print(f"Attempting to download from Hugging Face Hub with repo_id: {repo_id} and filename: {filename}")
            ckpt_path = hf_hub_download(repo_id=repo_id, filename=filename)
            print(f"Downloaded file path for {filename}: {ckpt_path}")
        else:
            raise ValueError("Invalid format for Hugging Face path. Expected format 'username/repo/filename'.")

    if extension.lower() == ".safetensors":
        import safetensors.torch
        state_dict = safetensors.torch.load_file(ckpt_path, device=location)
    else:
        state_dict = get_state_dict(torch.load(ckpt_path, map_location=torch.device(location)))

    print(f'Loaded state_dict from [{ckpt_path}]')
    return state_dict
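
# Usage sketch (illustrative; the repo id and filenames below are hypothetical):
#   sd = load_state_dict('someuser/somerepo/model.safetensors')  # fetched from the Hub
#   sd = load_state_dict('./checkpoints/model.ckpt')             # read from disk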


def create_model(config_path):
    # Instantiate the model described by an OmegaConf YAML config, on CPU.
    config = OmegaConf.load(config_path)
    model = instantiate_from_config(config.model).cpu()
    print(f'Loaded model config from [{config_path}]')
    return model


def create_SUPIR_model(config_path, SUPIR_sign=None, load_default_setting=False):
    config = OmegaConf.load(config_path)
    model = instantiate_from_config(config.model).cpu()
    print(f'Loaded model config from [{config_path}]')
    # Load the SDXL base weights first, then the SUPIR weights on top of them.
    if config.SDXL_CKPT is not None:
        model.load_state_dict(load_state_dict(config.SDXL_CKPT), strict=False)
    if config.SUPIR_CKPT is not None:
        model.load_state_dict(load_state_dict(config.SUPIR_CKPT), strict=False)
    # 'F' (fidelity) and 'Q' (quality) select between the two SUPIR variants.
    if SUPIR_sign is not None:
        assert SUPIR_sign in ['F', 'Q']
        if SUPIR_sign == 'F':
            model.load_state_dict(load_state_dict(config.SUPIR_CKPT_F), strict=False)
        elif SUPIR_sign == 'Q':
            model.load_state_dict(load_state_dict(config.SUPIR_CKPT_Q), strict=False)
    if load_default_setting:
        default_setting = config.default_setting
        return model, default_setting
    return model
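
# Usage sketch (illustrative; the config path is hypothetical):
#   model = create_SUPIR_model('options/SUPIR_v0.yaml', SUPIR_sign='Q')
#   model, default_setting = create_SUPIR_model('options/SUPIR_v0.yaml',
#                                               load_default_setting=True)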


def load_QF_ckpt(config_path):
    config = OmegaConf.load(config_path)

    def _resolve(ckpt_path, name):
        # Resolve 'username/repo/filename' Hub references; use local paths as-is.
        if '/' in ckpt_path and not os.path.exists(ckpt_path):
            parts = ckpt_path.split('/')
            if len(parts) != 3:
                raise ValueError(f"Invalid format for {name}. Expected format 'username/repo/filename'.")
            repo_id = f"{parts[0]}/{parts[1]}"
            filename = parts[2]
            print(f"Attempting to download {name} from repo_id: {repo_id} and filename: {filename}")
            ckpt_path = hf_hub_download(repo_id=repo_id, filename=filename)
            print(f"Downloaded {name} to: {ckpt_path}")
        return ckpt_path

    ckpt_F_path = _resolve(config.SUPIR_CKPT_F, 'SUPIR_CKPT_F')
    ckpt_Q_path = _resolve(config.SUPIR_CKPT_Q, 'SUPIR_CKPT_Q')

    ckpt_F = torch.load(ckpt_F_path, map_location='cpu')
    ckpt_Q = torch.load(ckpt_Q_path, map_location='cpu')
    return ckpt_Q, ckpt_F
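
# Usage sketch: load both checkpoints once, then switch variants at runtime
# without re-instantiating the model (the config path is hypothetical):
#   ckpt_Q, ckpt_F = load_QF_ckpt('options/SUPIR_v0.yaml')
#   model.load_state_dict(ckpt_Q, strict=False)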


def PIL2Tensor(img, upscale=1, min_size=1024, fix_resize=None):
    '''
    PIL.Image -> Tensor[C, H, W], RGB, [-1, 1]

    Also returns the pre-snap target size (h0, w0) so the model output can
    later be resized back to the intended resolution.
    '''
    # Scale by the requested upscale factor.
    w, h = img.size
    w *= upscale
    h *= upscale
    w0, h0 = round(w), round(h)
    # Enforce a minimum short side.
    if min(w, h) < min_size:
        _upscale = min_size / min(w, h)
        w *= _upscale
        h *= _upscale
    # Optionally pin the short side to a fixed value instead.
    if fix_resize is not None:
        _upscale = fix_resize / min(w, h)
        w *= _upscale
        h *= _upscale
        w0, h0 = round(w), round(h)
    # Snap both sides to multiples of 64 for the diffusion backbone.
    w = int(np.round(w / 64.0)) * 64
    h = int(np.round(h / 64.0)) * 64
    x = img.resize((w, h), Image.BICUBIC)
    x = np.array(x).round().clip(0, 255).astype(np.uint8)
    x = x / 255 * 2 - 1
    x = torch.tensor(x, dtype=torch.float32).permute(2, 0, 1)
    return x, h0, w0


def Tensor2PIL(x, h0, w0):
    '''
    Tensor[C, H, W], RGB, [-1, 1] -> PIL.Image
    '''
    x = x.unsqueeze(0)
    x = interpolate(x, size=(h0, w0), mode='bicubic')
    x = (x.squeeze(0).permute(1, 2, 0) * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
    return Image.fromarray(x)
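
# Round-trip sketch (illustrative; 'example.png' is a hypothetical file):
#   img = Image.open('example.png').convert('RGB')
#   x, h0, w0 = PIL2Tensor(img, upscale=2)
#   restored = Tensor2PIL(x, h0, w0)  # PIL.Image resized back to (w0, h0)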


def HWC3(x):
    # Normalize a uint8 image to 3-channel HWC: grayscale is replicated and
    # RGBA is alpha-composited onto a white background.
    assert x.dtype == np.uint8
    if x.ndim == 2:
        x = x[:, :, None]
    assert x.ndim == 3
    H, W, C = x.shape
    assert C == 1 or C == 3 or C == 4
    if C == 3:
        return x
    if C == 1:
        return np.concatenate([x, x, x], axis=2)
    if C == 4:
        color = x[:, :, 0:3].astype(np.float32)
        alpha = x[:, :, 3:4].astype(np.float32) / 255.0
        y = color * alpha + 255.0 * (1.0 - alpha)
        y = y.clip(0, 255).astype(np.uint8)
        return y


def upscale_image(input_image, upscale, min_size=None, unit_resolution=64):
    # Resize a uint8 HWC image by `upscale`, optionally enforcing a minimum
    # short side, with both sides snapped to multiples of `unit_resolution`.
    H, W, C = input_image.shape
    H = float(H)
    W = float(W)
    H *= upscale
    W *= upscale
    if min_size is not None:
        if min(H, W) < min_size:
            _upscale = min_size / min(W, H)
            W *= _upscale
            H *= _upscale
    H = int(np.round(H / unit_resolution)) * unit_resolution
    W = int(np.round(W / unit_resolution)) * unit_resolution
    # Lanczos for upscaling, area averaging for downscaling.
    img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if upscale > 1 else cv2.INTER_AREA)
    img = img.round().clip(0, 255).astype(np.uint8)
    return img
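
# Sketch (illustrative; img_uint8 is a placeholder for any uint8 HWC array):
#   img2x = upscale_image(img_uint8, upscale=2, min_size=1024)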


def fix_resize(input_image, size=512, unit_resolution=64):
    # Resize so the short side equals `size`, snapped to `unit_resolution`.
    H, W, C = input_image.shape
    H = float(H)
    W = float(W)
    upscale = size / min(H, W)
    H *= upscale
    W *= upscale
    H = int(np.round(H / unit_resolution)) * unit_resolution
    W = int(np.round(W / unit_resolution)) * unit_resolution
    img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if upscale > 1 else cv2.INTER_AREA)
    img = img.round().clip(0, 255).astype(np.uint8)
    return img


def Numpy2Tensor(img):
    '''
    np.ndarray[H, W, C], [0, 255] -> Tensor[C, H, W], RGB, [-1, 1]
    '''
    img = np.array(img) / 255 * 2 - 1
    img = torch.tensor(img, dtype=torch.float32).permute(2, 0, 1)
    return img


def Tensor2Numpy(x, h0=None, w0=None):
    '''
    Tensor[C, H, W], RGB, [-1, 1] -> np.ndarray[H, W, C], RGB, uint8 [0, 255]
    '''
    # Optionally resize back to (h0, w0) before converting.
    if h0 is not None and w0 is not None:
        x = x.unsqueeze(0)
        x = interpolate(x, size=(h0, w0), mode='bicubic')
        x = x.squeeze(0)
    x = (x.permute(1, 2, 0) * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
    return x
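
# Sketch: numpy image to model input and back (shapes as in the docstrings;
# img_uint8, h0 and w0 are assumed to come from the caller):
#   t = Numpy2Tensor(img_uint8)            # Tensor[C, H, W] in [-1, 1]
#   out = Tensor2Numpy(t, h0=h0, w0=w0)    # uint8 np.ndarray[H, W, C]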


def convert_dtype(dtype_str):
    # Map a short dtype string to the corresponding torch dtype.
    if dtype_str == 'fp32':
        return torch.float32
    elif dtype_str == 'fp16':
        return torch.float16
    elif dtype_str == 'bf16':
        return torch.bfloat16
    else:
        raise NotImplementedError(f"Unknown dtype string: {dtype_str!r}")
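
# Usage sketch:
#   dtype = convert_dtype('bf16')  # -> torch.bfloat16
#   model = model.to(dtype=dtype)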