|
|
|
|
|
|
|
import os
|
|
# Fetch the ImagenT5-3B weights repository into the current working directory.
# Skip the clone when the directory already exists so re-running the
# notebook/script does not error, and fail loudly (check=True) instead of
# silently continuing when git itself fails.
# NOTE(review): model.pt in a Hugging Face repo is presumably stored via Git LFS;
# without git-lfs installed the clone may only contain a pointer file — confirm.
import subprocess

if not os.path.isdir("ImagenT5-3B"):
    subprocess.run(
        ["git", "clone", "https://huggingface.co/Cene655/ImagenT5-3B"],
        check=True,
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from PIL import Image
|
|
from IPython.display import display
|
|
import torch as th
|
|
from imagen_pytorch.model_creation import create_model_and_diffusion as create_model_and_diffusion_dalle2
|
|
from imagen_pytorch.model_creation import model_and_diffusion_defaults as model_and_diffusion_defaults_dalle2
|
|
from transformers import AutoTokenizer
|
|
import cv2
|
|
|
|
import glob
|
|
import os
|
|
from basicsr.archs.rrdbnet_arch import RRDBNet
|
|
from realesrgan import RealESRGANer
|
|
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
|
|
from gfpgan import GFPGANer
|
|
|
|
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
has_cuda = th.cuda.is_available()
device = th.device('cuda' if has_cuda else 'cpu')
|
|
|
|
# Setting Up
|
|
|
|
def model_fn(x_t, ts, *, guidance_scale=5, **kwargs):
    """Wrap the global ``model`` with classifier-free guidance for ``diffusion.p_sample_loop``.

    The batch is split into two halves. The first half of ``x_t`` is duplicated,
    and the duplicated batch is run through ``model``. ``gen_img`` concatenates
    the conditional tokens first and the unconditional tokens second, so the
    first half of the output is conditional and the second half unconditional.
    The guided noise estimate is written to both halves.

    Args:
        x_t: Noisy samples for the current step. Assumed (2 * B, C, H, W) —
            TODO confirm against imagen_pytorch.
        ts: Timesteps, forwarded to ``model`` unchanged.
        guidance_scale: Classifier-free guidance weight (keyword-only). The
            default 5 is the previously hard-coded value; 1 yields the plain
            conditional prediction. It can also be supplied through the
            sampler's ``model_kwargs``.
        **kwargs: Forwarded to ``model`` (``gen_img`` passes ``tokens`` and ``mask``).

    Returns:
        The model output with its first 3 channels replaced by the guided
        estimate in both halves; any remaining channels are passed through.
    """
    half = x_t[: len(x_t) // 2]
    combined = th.cat([half, half], dim=0)
    model_out = model(combined, ts, **kwargs)

    # Channels 0..2 are the noise prediction that gets guided; any extra
    # channels (``rest``) are returned untouched.
    eps, rest = model_out[:, :3], model_out[:, 3:]
    cond_eps, uncond_eps = th.split(eps, len(eps) // 2, dim=0)
    half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)
    eps = th.cat([half_eps, half_eps], dim=0)
    return th.cat([eps, rest], dim=1)
|
|
|
|
def show_images(batch: th.Tensor):
    """Display a batch of images inline as one horizontal strip.

    Pixel values are mapped with ``(x + 1) * 127.5`` and clamped to [0, 255],
    so inputs are expected in [-1, 1]. The layout assumes an (N, 3, H, W)
    tensor; the N images are placed side by side.
    """
    pixels = (batch + 1) * 127.5
    pixels = pixels.round().clamp(0, 255).to(th.uint8).cpu()

    height = batch.shape[2]
    # (N, C, H, W) -> (H, N, W, C) -> (H, N*W, 3): concatenate images along width.
    strip = pixels.permute(2, 0, 3, 1).reshape([height, -1, 3])

    display(Image.fromarray(strip.numpy()))
|
|
|
|
def get_numpy_img(img):
    """Return a batch of images as one horizontal uint8 strip with swapped channel order.

    Uses the same pixel mapping and layout as ``show_images`` (values assumed
    in [-1, 1], layout assumed (N, 3, H, W)). The result is then run through
    ``cv2.cvtColor`` to reverse the channel order. Presumably this produces BGR
    for OpenCV-style consumers — confirm with callers.
    """
    as_uint8 = ((img + 1) * 127.5).round().clamp(0, 255).to(th.uint8).cpu()

    rows = img.shape[2]
    strip = as_uint8.permute(2, 0, 3, 1).reshape([rows, -1, 3]).numpy()

    # COLOR_RGB2BGR and COLOR_BGR2RGB are the same conversion code in OpenCV
    # (a channel swap); RGB2BGR names the direction this data actually goes.
    return cv2.cvtColor(strip, cv2.COLOR_RGB2BGR)
|
|
|
|
def _fix_path(path):
|
|
d = th.load(path)
|
|
checkpoint = {}
|
|
for key in d.keys():
|
|
checkpoint[key.replace('module.','')] = d[key]
|
|
return checkpoint
|
|
|
|
# Configure and build the base text-to-image model and its diffusion process.
options = model_and_diffusion_defaults_dalle2()
options['use_fp16'] = False
options['diffusion_steps'] = 200
options['num_res_blocks'] = 3
options['t5_name'] = 't5-3b'
options['cache_text_emb'] = True
model, diffusion = create_model_and_diffusion_dalle2(**options)

model.eval()
model.to(device)

# The weights repository is cloned into ./ImagenT5-3B (relative to the working
# directory) at the top of this file. Resolve the checkpoint the same way
# instead of hard-coding Colab's /content, so the script also runs outside
# Colab. Colab's default working directory is /content, so the path there is
# unchanged.
model.load_state_dict(_fix_path(os.path.join('ImagenT5-3B', 'model.pt')))
print('total base parameters', sum(x.numel() for x in model.parameters()))
|
|
|
|
# Output: total base parameters 1550556742
|
|
|
|
# Count the model's parameters; the bare name on the last line only shows the
# value when run as a notebook cell.
num_params = sum(map(th.numel, model.parameters()))
num_params
|
|
|
|
# Output: 1550556742
|
|
|
|
# Real-ESRGAN x4 generator (RRDB backbone) wrapped for inference.
realesrgan_model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64,
                           num_block=23, num_grow_ch=32, scale=4)

netscale = 4

# NOTE(review): this weight path assumes a Colab layout with the Real-ESRGAN
# repo and its pretrained weights under /content; nothing in this file
# downloads them — confirm.
upsampler = RealESRGANer(
    scale=netscale,
    model_path='/content/Real-ESRGAN/experiments/pretrained_models/RealESRGAN_x4plus.pth',
    model=realesrgan_model,
    tile=0,
    tile_pad=10,
    pre_pad=0,
    # Half precision is only enabled when CUDA is present. The device selection
    # at the top allows a CPU fallback, and fp16 inference is generally
    # unsupported on CPU (Real-ESRGAN recommends fp32 there).
    half=has_cuda,
)
|
|
|
|
# GFPGAN v1.3 face restorer, upscaling by 4. The Real-ESRGAN ``upsampler``
# defined above is passed as its background upsampler.
face_enhancer = GFPGANer(
    bg_upsampler=upsampler,
    arch='clean',
    channel_multiplier=2,
    upscale=4,
    model_path='https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth',
)
|
|
|
|
# Tokenizer for the T5 text encoder selected in options['t5_name'].
# model_max_length=512 is the value transformers currently assumes for t5-3b.
# Passing it explicitly keeps behavior unchanged and avoids the FutureWarning
# it otherwise emits. gen_img pads/truncates to max_length=128 explicitly anyway.
tokenizer = AutoTokenizer.from_pretrained(options['t5_name'], model_max_length=512)
|
|
|
|
# Captured output from creating the tokenizer (transformers FutureWarning):
# /usr/local/lib/python3.7/dist-packages/transformers/models/t5/tokenization_t5_fast.py:161: FutureWarning: This tokenizer was incorrectly instantiated with a model max length of 512 which will be corrected in Transformers v5.
# For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.
# - Be aware that you SHOULD NOT rely on t5-3b automatically truncating your input to 512 when padding/encoding.
# - If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.
# - To avoid this warning, please instantiate this tokenizer with `model_max_length` set to your preferred value.
#   FutureWarning,
|
|
|
|
|
|
|
|
# Sample prompt for manual experiments. gen_img takes its prompt as an argument
# (the demo supplies it from the textbox), so this module-level value is not
# what the demo uses.
prompt = "A photo of cat"
|
|
|
|
def gen_img(prompt, batch_size=4):
    """Sample images for ``prompt`` using classifier-free guidance.

    Relies on the module-level ``tokenizer``, ``model``, ``diffusion``,
    ``model_fn`` and ``device`` defined earlier in this file.

    Args:
        prompt: Text to condition on.
        batch_size: Number of images to sample. Defaults to 4, the previously
            hard-coded value.

    Returns:
        The first ``batch_size`` entries of ``diffusion.p_sample_loop``'s result
        for a ``(2 * batch_size, 3, 64, 64)`` sampling shape. This is the half
        paired with the conditional tokens.

    Raises:
        ValueError: if ``batch_size`` is less than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    def _encode(text):
        # Pad/truncate to a fixed 128 tokens so conditional and unconditional
        # rows have equal length and can be concatenated.
        return tokenizer(
            text,
            max_length=128,
            padding="max_length",
            truncation=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors="pt",
        )

    text_encoding = _encode(prompt)
    # The empty string serves as the unconditional input for guidance.
    uncond_text_encoding = _encode('')

    def _stack(key):
        # Repeat the single encoded row batch_size times per branch. Conditional
        # rows go first, unconditional second: the layout model_fn splits on.
        cond = text_encoding[key][:1].repeat(batch_size, 1)
        uncond = uncond_text_encoding[key][:1].repeat(batch_size, 1)
        return th.cat((cond, uncond)).to(device)

    model_kwargs = {
        "tokens": _stack('input_ids'),
        "mask": _stack('attention_mask'),
    }

    # Generation
    # Clear the model's cache before sampling, and again afterwards even if
    # sampling raises. Presumably this drops text embeddings cached because
    # options['cache_text_emb'] is True — confirm against imagen_pytorch.
    model.del_cache()
    try:
        sample = diffusion.p_sample_loop(
            model_fn,
            (batch_size * 2, 3, 64, 64),
            clip_denoised=True,
            model_kwargs=model_kwargs,
            # Use the device picked at start-up rather than a hard-coded 'cuda',
            # so sampling also works on the CPU fallback.
            device=device,
            progress=True,
        )[:batch_size]
    finally:
        model.del_cache()

    return sample
|
|
|
|
# Gradio UI. `gr` is used here but was not imported anywhere in this file.
import gradio as gr


def _sample_to_rgb_strip(sample):
    """Convert gen_img's tensor into an (H, N*W, 3) uint8 RGB array for display.

    Uses the same mapping as ``show_images``: values assumed in [-1, 1], layout
    (N, 3, H, W). gr.Image output components take numpy arrays, PIL images or
    file paths, not torch tensors.
    """
    scaled = ((sample + 1) * 127.5).round().clamp(0, 255).to(th.uint8).cpu()
    return scaled.permute(2, 0, 3, 1).reshape([sample.shape[2], -1, 3]).numpy()


def _generate_for_ui(text):
    """Gradio callback: sample images for ``text`` and return them as one RGB strip."""
    return _sample_to_rgb_strip(gen_img(text))


def _upscale_for_ui(image_path):
    """Gradio callback: upscale the generated image with GFPGAN + Real-ESRGAN.

    The previous wiring referenced ``upscale_img``, which is not defined in
    this file. This helper takes its place.

    Args:
        image_path: Path supplied by the ``type="filepath"`` image component,
            or None when no image has been generated yet.

    Returns:
        The restored image as an RGB uint8 array, or None when there is no input.

    Raises:
        ValueError: if OpenCV cannot read ``image_path``.
    """
    if image_path is None:
        return None
    # GFPGAN/Real-ESRGAN operate on OpenCV-style BGR arrays.
    bgr = cv2.imread(image_path, cv2.IMREAD_COLOR)
    if bgr is None:
        raise ValueError(f"could not read image: {image_path}")
    _, _, restored = face_enhancer.enhance(
        bgr, has_aligned=False, only_center_face=False, paste_back=True
    )
    return cv2.cvtColor(restored, cv2.COLOR_BGR2RGB)


demo = gr.Blocks()

with demo:
    gr.Markdown("<h1><center>cene555/Imagen-pytorch</center></h1>")
    gr.Markdown(
        # The repo link previously pointed at a single image file (images/2.jpg).
        "<div>github repo <a href='https://github.com/cene555/Imagen-pytorch'>here</a></div>"
        "<div>hf model <a href='https://huggingface.co/Cene655/ImagenT5-3B/tree/main'>here</a></div>"
    )

    with gr.Row():
        b0 = gr.Button("generate")
        b1 = gr.Button("upscale")

    with gr.Row():
        desc = gr.Textbox(label="description", placeholder="an impressionist painting of a white vase")

    with gr.Row():
        # Distinct labels: both images were previously labelled "portrait".
        intermediate_image = gr.Image(label="generated", type="filepath", shape=(256, 256))
        output_image = gr.Image(label="upscaled", type="filepath", shape=(256, 256))

    b0.click(_generate_for_ui, inputs=[desc], outputs=[intermediate_image])
    b1.click(_upscale_for_ui, inputs=[intermediate_image], outputs=output_image)

demo.launch(enable_queue=True, debug=True)
|
|
|
|
|