Logan Zoellner committed on
Commit
bfc97b7
·
1 Parent(s): 73630d0

initial commit

Browse files
Files changed (2) hide show
  1. app.py +213 -0
  2. requirements.txt +8 -0
app.py ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ #%%capture
4
+ #!git lfs install
5
+ #!git clone https://huggingface.co/Cene655/ImagenT5-3B
6
+
7
+ #%%capture
8
+ #!pip install git+https://github.com/cene555/Imagen-pytorch.git
9
+ #!pip install git+https://github.com/openai/CLIP.git
10
+
11
+ #%%capture
12
+ #!git clone https://github.com/xinntao/Real-ESRGAN.git
13
+
14
+ #%cd Real-ESRGAN
15
+
16
+ #%%capture
17
+ #!pip install basicsr
18
+ # facexlib and gfpgan are for face enhancement
19
+ #!pip install facexlib
20
+ #!pip install gfpgan
21
+
22
+ #%%capture
23
+ #!pip install -r requirements.txt
24
+ #!python setup.py develop
25
+ #!wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth -P experiments/pretrained_models
26
+
27
+ #Imports
28
+
29
+ from PIL import Image
30
+ from IPython.display import display
31
+ import torch as th
32
+ from imagen_pytorch.model_creation import create_model_and_diffusion as create_model_and_diffusion_dalle2
33
+ from imagen_pytorch.model_creation import model_and_diffusion_defaults as model_and_diffusion_defaults_dalle2
34
+ from transformers import AutoTokenizer
35
+ import cv2
36
+
37
+ import glob
38
+ import os
39
+ from basicsr.archs.rrdbnet_arch import RRDBNet
40
+ from realesrgan import RealESRGANer
41
+ from realesrgan.archs.srvgg_arch import SRVGGNetCompact
42
+ from gfpgan import GFPGANer
43
+
44
# Pick the compute device once; later code moves the model and its inputs
# onto `device`.
has_cuda = th.cuda.is_available()
device = th.device('cuda' if has_cuda else 'cpu')
46
+
47
# Setting Up

def model_fn(x_t, ts, **kwargs):
    """Run the global ``model`` with classifier-free guidance.

    Only the first half of ``x_t`` is used: it is duplicated so both halves of
    the batch see the same noisy images. ``gen_img`` supplies the conditional
    tokens for the first half and the empty-prompt tokens for the second half
    through ``kwargs``.

    Args:
        x_t: Batch of noisy images; its length is expected to be even.
        ts: Timesteps, forwarded to ``model`` unchanged.
        **kwargs: Extra conditioning forwarded to ``model``.

    Returns:
        A tensor shaped like the model output whose first three channels hold
        the guided prediction (repeated for both halves) and whose remaining
        channels are passed through untouched.
    """
    guidance_scale = 5
    half = x_t[: len(x_t) // 2]
    combined = th.cat([half, half], dim=0)
    model_out = model(combined, ts, **kwargs)
    # First three channels are the noise prediction; the rest is passed along.
    eps, rest = model_out[:, :3], model_out[:, 3:]
    cond_eps, uncond_eps = th.split(eps, len(eps) // 2, dim=0)
    # Push the prediction away from the unconditional one.
    half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)
    eps = th.cat([half_eps, half_eps], dim=0)
    return th.cat([eps, rest], dim=1)
59
+
60
def show_images(batch: th.Tensor):
    """Display a batch of images inline.

    Values are mapped from [-1, 1] to [0, 255], and the images of the batch
    are laid out side by side in a single strip before being shown.

    Args:
        batch: Tensor indexed as (image, channel, height, width) with three
            channels.
    """
    scaled = ((batch + 1) * 127.5).round().clamp(0, 255).to(th.uint8).cpu()
    # (N, C, H, W) -> (H, N, W, C) -> (H, N*W, 3): one row of images.
    reshaped = scaled.permute(2, 0, 3, 1).reshape([batch.shape[2], -1, 3])
    display(Image.fromarray(reshaped.numpy()))
65
+
66
def get_numpy_img(img):
    """Convert an image tensor in [-1, 1] to a uint8 array with reversed channels.

    The images of the batch are laid out side by side, exactly as in
    ``show_images``, and the last (channel) axis is then reversed. The
    original used ``cv2.cvtColor(..., cv2.COLOR_BGR2RGB)`` for this; on a
    three-channel uint8 image that conversion is a plain channel reversal, so
    it is done directly on the tensor here with identical output.

    Args:
        img: Tensor indexed as (image, channel, height, width) with three
            channels.

    Returns:
        A ``numpy.ndarray`` of dtype uint8 and shape (height, images * width, 3).
    """
    scaled = ((img + 1) * 127.5).round().clamp(0, 255).to(th.uint8).cpu()
    # (N, C, H, W) -> (H, N, W, C) -> (H, N*W, 3): one row of images.
    reshaped = scaled.permute(2, 0, 3, 1).reshape([img.shape[2], -1, 3])
    # NOTE(review): presumably the consumer expects OpenCV's BGR order — confirm.
    return reshaped.flip(-1).numpy()
70
+
71
def _fix_path(path):
    """Load a checkpoint and strip ``'module.'`` from every key.

    Args:
        path: Location of a checkpoint saved with ``torch.save`` that holds a
            mapping from parameter names to tensors.

    Returns:
        A new dict with the same values and with every occurrence of
        ``'module.'`` removed from the keys.
    """
    # Load onto the CPU so a checkpoint saved from a GPU also opens on a
    # CPU-only host; ``load_state_dict`` copies into the model's own device.
    # NOTE: ``th.load`` unpickles the file — only use trusted checkpoints.
    state = th.load(path, map_location='cpu')
    return {key.replace('module.', ''): value for key, value in state.items()}
77
+
78
# Start from the library defaults and override what this checkpoint needs.
options = model_and_diffusion_defaults_dalle2()
options['use_fp16'] = False
options['diffusion_steps'] = 200
options['num_res_blocks'] = 3
options['t5_name'] = 't5-3b'
options['cache_text_emb'] = True
model, diffusion = create_model_and_diffusion_dalle2(**options)

model.eval()

#if has_cuda:
#    model.convert_to_fp16()

model.to(device)

# NOTE(review): the checkpoint path is the Colab clone location used by the
# commented-out setup commands at the top of the file — confirm it exists in
# the deployment environment.
model.load_state_dict(_fix_path('/content/ImagenT5-3B/model.pt'))

# Count once and reuse; the notebook computed the same sum twice.
num_params = sum(param.numel() for param in model.parameters())
print('total base parameters', num_params)
# Output recorded in the original notebook run:
#   total base parameters 1550556742
102
+
103
# Real-ESRGAN x4 network used for background upscaling.
realesrgan_model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64,
                           num_block=23, num_grow_ch=32, scale=4)

netscale = 4

upsampler = RealESRGANer(
    scale=netscale,
    model_path='/content/Real-ESRGAN/experiments/pretrained_models/RealESRGAN_x4plus.pth',
    model=realesrgan_model,
    tile=0,
    tile_pad=10,
    pre_pad=0,
    # Half precision only when a GPU is present; the rest of the script
    # falls back to the CPU, so this must not be unconditional.
    half=has_cuda,
)

# GFPGAN restores faces and delegates everything else to `upsampler`.
face_enhancer = GFPGANer(
    model_path='https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth',
    upscale=netscale,
    arch='clean',
    channel_multiplier=2,
    bg_upsampler=upsampler,
)

tokenizer = AutoTokenizer.from_pretrained(options['t5_name'])

# Warning recorded in the original notebook run (kept for reference):
#   /usr/local/lib/python3.7/dist-packages/transformers/models/t5/tokenization_t5_fast.py:161: FutureWarning: This tokenizer was incorrectly instantiated with a model max length of 512 which will be corrected in Transformers v5.
#   For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.
#   - Be aware that you SHOULD NOT rely on t5-3b automatically truncating your input to 512 when padding/encoding.
#   - If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.
#   - To avoid this warning, please instantiate this tokenizer with `model_max_length` set to your preferred value.
#   FutureWarning,
134
+
135
#@title What do you want to generate?

prompt = 'A photo of cat'#@param {type:"string"}

def gen_img(prompt):
    """Sample a batch of 64x64 images for ``prompt``.

    The prompt and an empty string are tokenized identically. Both encodings
    are repeated ``batch_size`` times and concatenated, prompt first, so that
    ``model_fn`` can apply classifier-free guidance.

    Args:
        prompt: Text description of the desired image.

    Returns:
        The first ``batch_size`` samples produced by
        ``diffusion.p_sample_loop``, requested with shape (3, 64, 64) each.
        NOTE(review): the Gradio callback wires this return value straight
        into a ``gr.Image`` output — confirm that component accepts it.
    """
    batch_size = 4
    tokenizer_kwargs = dict(
        max_length=128,
        padding="max_length",
        truncation=True,
        return_attention_mask=True,
        add_special_tokens=True,
        return_tensors="pt",
    )
    text_encoding = tokenizer(prompt, **tokenizer_kwargs)
    uncond_text_encoding = tokenizer('', **tokenizer_kwargs)

    def _tile(encoding, field):
        # Repeat the single encoded row once per image in the batch.
        return encoding[field][0].repeat(batch_size, 1)

    cond_tokens = _tile(text_encoding, 'input_ids')
    uncond_tokens = _tile(uncond_text_encoding, 'input_ids')
    cond_attention_mask = _tile(text_encoding, 'attention_mask')
    uncond_attention_mask = _tile(uncond_text_encoding, 'attention_mask')

    model_kwargs = {
        "tokens": th.cat((cond_tokens, uncond_tokens)).to(device),
        "mask": th.cat((cond_attention_mask, uncond_attention_mask)).to(device),
    }

    # Generation
    model.del_cache()
    sample = diffusion.p_sample_loop(
        model_fn,
        (batch_size * 2, 3, 64, 64),
        clip_denoised=True,
        model_kwargs=model_kwargs,
        # Use the module-level device rather than a hard-coded 'cuda' so the
        # sampler matches where the model and its inputs were placed.
        device=device,
        progress=True,
    )[:batch_size]
    model.del_cache()

    return sample
187
+
188
# The original never imported Gradio although every line below uses `gr`.
import gradio as gr

# Two-button UI: "generate" runs gen_img on the description, "upscale"
# post-processes the generated image.
demo = gr.Blocks()

with demo:
    gr.Markdown("<h1><center>cene555/Imagen-pytorch</center></h1>")
    gr.Markdown(
        "<div>github repo <a href='https://github.com/cene555/Imagen-pytorch/blob/main/images/2.jpg'>here</a></div>"
        "<div>hf model <a href='https://huggingface.co/Cene655/ImagenT5-3B/tree/main'>here</a></div>"
    )

    with gr.Row():
        b0 = gr.Button("generate")
        b1 = gr.Button("upscale")

    with gr.Row():
        desc = gr.Textbox(label="description", placeholder="an impressionist painting of a white vase")

    with gr.Row():
        intermediate_image = gr.Image(label="portrait", type="filepath", shape=(256, 256))
        output_image = gr.Image(label="portrait", type="filepath", shape=(256, 256))

    b0.click(gen_img, inputs=[desc], outputs=[intermediate_image])
    # NOTE(review): `upscale_img` is not defined anywhere in this file, so this
    # line raises NameError at start-up until the function is added.
    b1.click(upscale_img, inputs=[intermediate_image], outputs=output_image)
    #examples=examples

demo.launch(enable_queue=True, debug=True)
213
+
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
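+ # NOTE(review): the next entry is a Hugging Face model-weights repository; it is presumably not pip-installable, and app.py reads the weights from /content/ImagenT5-3B/model.pt — confirm and fetch the weights separately if so.
+ git+https://huggingface.co/Cene655/ImagenT5-3B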
2
+ git+https://github.com/cene555/Imagen-pytorch.git
3
+ git+https://github.com/openai/CLIP.git
4
+ git+https://github.com/xinntao/Real-ESRGAN.git
5
+ basicsr
6
+ facexlib
7
+ gfpgan
8
+