Spaces:
Running
on
Zero
Running
on
Zero
Delete gradio.ipynb
Browse files- gradio.ipynb +0 -317
gradio.ipynb
DELETED
@@ -1,317 +0,0 @@
|
|
1 |
-
{
|
2 |
-
"cells": [
|
3 |
-
{
|
4 |
-
"cell_type": "code",
|
5 |
-
"execution_count": 1,
|
6 |
-
"metadata": {},
|
7 |
-
"outputs": [
|
8 |
-
{
|
9 |
-
"name": "stderr",
|
10 |
-
"output_type": "stream",
|
11 |
-
"text": [
|
12 |
-
"/home/amil/anaconda3/envs/dblora2/lib/python3.8/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
|
13 |
-
" from .autonotebook import tqdm as notebook_tqdm\n"
|
14 |
-
]
|
15 |
-
},
|
16 |
-
{
|
17 |
-
"name": "stdout",
|
18 |
-
"output_type": "stream",
|
19 |
-
"text": [
|
20 |
-
"[2024-06-28 00:45:26,702] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n"
|
21 |
-
]
|
22 |
-
}
|
23 |
-
],
|
24 |
-
"source": [
|
25 |
-
"import gradio as gr\n",
|
26 |
-
"import sys\n",
|
27 |
-
"import os \n",
|
28 |
-
"import tqdm\n",
|
29 |
-
"sys.path.append(os.path.abspath(os.path.join(\"\", \"..\")))\n",
|
30 |
-
"import torch\n",
|
31 |
-
"import gc\n",
|
32 |
-
"import warnings\n",
|
33 |
-
"warnings.filterwarnings(\"ignore\")\n",
|
34 |
-
"from PIL import Image\n",
|
35 |
-
"from utils import load_models, save_model_w2w, save_model_for_diffusers\n",
|
36 |
-
"from sampling import sample_weights"
|
37 |
-
]
|
38 |
-
},
|
39 |
-
{
|
40 |
-
"cell_type": "code",
|
41 |
-
"execution_count": 2,
|
42 |
-
"metadata": {},
|
43 |
-
"outputs": [],
|
44 |
-
"source": [
|
45 |
-
"global device\n",
|
46 |
-
"global generator \n",
|
47 |
-
"global unet\n",
|
48 |
-
"global vae \n",
|
49 |
-
"global text_encoder\n",
|
50 |
-
"global tokenizer\n",
|
51 |
-
"global noise_scheduler\n",
|
52 |
-
"device = \"cuda:0\"\n",
|
53 |
-
"generator = torch.Generator(device=device)"
|
54 |
-
]
|
55 |
-
},
|
56 |
-
{
|
57 |
-
"cell_type": "code",
|
58 |
-
"execution_count": 3,
|
59 |
-
"metadata": {},
|
60 |
-
"outputs": [],
|
61 |
-
"source": [
|
62 |
-
"mean = torch.load(\"files/mean.pt\").bfloat16().to(device)\n",
|
63 |
-
"std = torch.load(\"files/std.pt\").bfloat16().to(device)\n",
|
64 |
-
"v = torch.load(\"files/V.pt\").bfloat16().to(device)\n",
|
65 |
-
"proj = torch.load(\"files/proj_1000pc.pt\").bfloat16().to(device)\n",
|
66 |
-
"df = torch.load(\"files/identity_df.pt\")\n",
|
67 |
-
"weight_dimensions = torch.load(\"files/weight_dimensions.pt\")"
|
68 |
-
]
|
69 |
-
},
|
70 |
-
{
|
71 |
-
"cell_type": "code",
|
72 |
-
"execution_count": 4,
|
73 |
-
"metadata": {},
|
74 |
-
"outputs": [
|
75 |
-
{
|
76 |
-
"name": "stderr",
|
77 |
-
"output_type": "stream",
|
78 |
-
"text": [
|
79 |
-
"Loading pipeline components...: 100%|██████████| 6/6 [00:00<00:00, 10.79it/s]\n"
|
80 |
-
]
|
81 |
-
},
|
82 |
-
{
|
83 |
-
"name": "stdout",
|
84 |
-
"output_type": "stream",
|
85 |
-
"text": [
|
86 |
-
"\n"
|
87 |
-
]
|
88 |
-
}
|
89 |
-
],
|
90 |
-
"source": [
|
91 |
-
"unet, vae, text_encoder, tokenizer, noise_scheduler = load_models(device)"
|
92 |
-
]
|
93 |
-
},
|
94 |
-
{
|
95 |
-
"cell_type": "code",
|
96 |
-
"execution_count": 5,
|
97 |
-
"metadata": {},
|
98 |
-
"outputs": [],
|
99 |
-
"source": [
|
100 |
-
"global network"
|
101 |
-
]
|
102 |
-
},
|
103 |
-
{
|
104 |
-
"cell_type": "code",
|
105 |
-
"execution_count": 6,
|
106 |
-
"metadata": {},
|
107 |
-
"outputs": [],
|
108 |
-
"source": [
|
109 |
-
"def sample_model():\n",
|
110 |
-
" global unet\n",
|
111 |
-
" del unet\n",
|
112 |
-
" global network\n",
|
113 |
-
" unet, _, _, _, _ = load_models(device)\n",
|
114 |
-
" network = sample_weights(unet, proj, mean, std, v[:, :1000], device, factor = 1.00)\n",
|
115 |
-
" \n"
|
116 |
-
]
|
117 |
-
},
|
118 |
-
{
|
119 |
-
"cell_type": "code",
|
120 |
-
"execution_count": 7,
|
121 |
-
"metadata": {},
|
122 |
-
"outputs": [],
|
123 |
-
"source": [
|
124 |
-
"@torch.no_grad()\n",
|
125 |
-
"def inference( prompt, negative_prompt, guidance_scale, ddim_steps, seed):\n",
|
126 |
-
" global device\n",
|
127 |
-
" global generator \n",
|
128 |
-
" global unet\n",
|
129 |
-
" global vae \n",
|
130 |
-
" global text_encoder\n",
|
131 |
-
" global tokenizer\n",
|
132 |
-
" global noise_scheduler\n",
|
133 |
-
" generator = generator.manual_seed(seed)\n",
|
134 |
-
" latents = torch.randn(\n",
|
135 |
-
" (1, unet.in_channels, 512 // 8, 512 // 8),\n",
|
136 |
-
" generator = generator,\n",
|
137 |
-
" device = device\n",
|
138 |
-
" ).bfloat16()\n",
|
139 |
-
" \n",
|
140 |
-
"\n",
|
141 |
-
" text_input = tokenizer(prompt, padding=\"max_length\", max_length=tokenizer.model_max_length, truncation=True, return_tensors=\"pt\")\n",
|
142 |
-
"\n",
|
143 |
-
" text_embeddings = text_encoder(text_input.input_ids.to(device))[0]\n",
|
144 |
-
"\n",
|
145 |
-
" max_length = text_input.input_ids.shape[-1]\n",
|
146 |
-
" uncond_input = tokenizer(\n",
|
147 |
-
" [negative_prompt], padding=\"max_length\", max_length=max_length, return_tensors=\"pt\"\n",
|
148 |
-
" )\n",
|
149 |
-
" uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]\n",
|
150 |
-
" text_embeddings = torch.cat([uncond_embeddings, text_embeddings])\n",
|
151 |
-
" noise_scheduler.set_timesteps(ddim_steps) \n",
|
152 |
-
" latents = latents * noise_scheduler.init_noise_sigma\n",
|
153 |
-
" \n",
|
154 |
-
" for i,t in enumerate(tqdm.tqdm(noise_scheduler.timesteps)):\n",
|
155 |
-
" latent_model_input = torch.cat([latents] * 2)\n",
|
156 |
-
" latent_model_input = noise_scheduler.scale_model_input(latent_model_input, timestep=t)\n",
|
157 |
-
" with network:\n",
|
158 |
-
" noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings, timestep_cond= None).sample\n",
|
159 |
-
" #guidance\n",
|
160 |
-
" noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n",
|
161 |
-
" noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)\n",
|
162 |
-
" latents = noise_scheduler.step(noise_pred, t, latents).prev_sample\n",
|
163 |
-
" \n",
|
164 |
-
" latents = 1 / 0.18215 * latents\n",
|
165 |
-
" image = vae.decode(latents).sample\n",
|
166 |
-
" image = (image / 2 + 0.5).clamp(0, 1)\n",
|
167 |
-
" image = image.detach().cpu().float().permute(0, 2, 3, 1).numpy()[0]\n",
|
168 |
-
"\n",
|
169 |
-
" image = Image.fromarray((image * 255).round().astype(\"uint8\"))\n",
|
170 |
-
"\n",
|
171 |
-
" return [image] "
|
172 |
-
]
|
173 |
-
},
|
174 |
-
{
|
175 |
-
"cell_type": "code",
|
176 |
-
"execution_count": 8,
|
177 |
-
"metadata": {},
|
178 |
-
"outputs": [
|
179 |
-
{
|
180 |
-
"name": "stdout",
|
181 |
-
"output_type": "stream",
|
182 |
-
"text": [
|
183 |
-
"Running on local URL: http://127.0.0.1:7860\n",
|
184 |
-
"Running on public URL: https://bc89b27b9704787832.gradio.live\n",
|
185 |
-
"\n",
|
186 |
-
"This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)\n"
|
187 |
-
]
|
188 |
-
},
|
189 |
-
{
|
190 |
-
"data": {
|
191 |
-
"text/html": [
|
192 |
-
"<div><iframe src=\"https://bc89b27b9704787832.gradio.live\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
|
193 |
-
],
|
194 |
-
"text/plain": [
|
195 |
-
"<IPython.core.display.HTML object>"
|
196 |
-
]
|
197 |
-
},
|
198 |
-
"metadata": {},
|
199 |
-
"output_type": "display_data"
|
200 |
-
},
|
201 |
-
{
|
202 |
-
"data": {
|
203 |
-
"text/plain": []
|
204 |
-
},
|
205 |
-
"execution_count": 8,
|
206 |
-
"metadata": {},
|
207 |
-
"output_type": "execute_result"
|
208 |
-
},
|
209 |
-
{
|
210 |
-
"name": "stderr",
|
211 |
-
"output_type": "stream",
|
212 |
-
"text": [
|
213 |
-
"Loading pipeline components...: 100%|██████████| 6/6 [00:00<00:00, 8.95it/s]\n",
|
214 |
-
"Traceback (most recent call last):\n",
|
215 |
-
" File \"/home/amil/anaconda3/envs/dblora2/lib/python3.8/site-packages/gradio/routes.py\", line 437, in run_predict\n",
|
216 |
-
" output = await app.get_blocks().process_api(\n",
|
217 |
-
" File \"/home/amil/anaconda3/envs/dblora2/lib/python3.8/site-packages/gradio/blocks.py\", line 1352, in process_api\n",
|
218 |
-
" result = await self.call_function(\n",
|
219 |
-
" File \"/home/amil/anaconda3/envs/dblora2/lib/python3.8/site-packages/gradio/blocks.py\", line 1077, in call_function\n",
|
220 |
-
" prediction = await anyio.to_thread.run_sync(\n",
|
221 |
-
" File \"/home/amil/anaconda3/envs/dblora2/lib/python3.8/site-packages/anyio/to_thread.py\", line 56, in run_sync\n",
|
222 |
-
" return await get_async_backend().run_sync_in_worker_thread(\n",
|
223 |
-
" File \"/home/amil/anaconda3/envs/dblora2/lib/python3.8/site-packages/anyio/_backends/_asyncio.py\", line 2134, in run_sync_in_worker_thread\n",
|
224 |
-
" return await future\n",
|
225 |
-
" File \"/home/amil/anaconda3/envs/dblora2/lib/python3.8/site-packages/anyio/_backends/_asyncio.py\", line 851, in run\n",
|
226 |
-
" result = context.run(func, *args)\n",
|
227 |
-
" File \"/home/amil/anaconda3/envs/dblora2/lib/python3.8/site-packages/torch/utils/_contextlib.py\", line 115, in decorate_context\n",
|
228 |
-
" return func(*args, **kwargs)\n",
|
229 |
-
" File \"/tmp/ipykernel_2844069/1186401021.py\", line 12, in inference\n",
|
230 |
-
" (1, unet.in_channels, 512 // 8, 512 // 8),\n",
|
231 |
-
"NameError: name 'unet' is not defined\n"
|
232 |
-
]
|
233 |
-
},
|
234 |
-
{
|
235 |
-
"name": "stdout",
|
236 |
-
"output_type": "stream",
|
237 |
-
"text": [
|
238 |
-
"\n"
|
239 |
-
]
|
240 |
-
}
|
241 |
-
],
|
242 |
-
"source": [
|
243 |
-
"css = ''\n",
|
244 |
-
"with gr.Blocks(css=css) as demo:\n",
|
245 |
-
" gr.Markdown(\"# <em>weights2weights</em> Demo\")\n",
|
246 |
-
" gr.Markdown(\"Demo for the [h94/IP-Adapter-FaceID model](https://huggingface.co/h94/IP-Adapter-FaceID) - Generate AI images with your own face - Non-commercial license\")\n",
|
247 |
-
" with gr.Row():\n",
|
248 |
-
" with gr.Column():\n",
|
249 |
-
" files = gr.Files(\n",
|
250 |
-
" label=\"Upload a photo of your face to invert, or sample a new model\",\n",
|
251 |
-
" file_types=[\"image\"]\n",
|
252 |
-
" )\n",
|
253 |
-
" uploaded_files = gr.Gallery(label=\"Your images\", visible=False, columns=5, rows=1, height=125)\n",
|
254 |
-
"\n",
|
255 |
-
" sample = gr.Button(\"Sample New Model\")\n",
|
256 |
-
"\n",
|
257 |
-
" with gr.Column(visible=False) as clear_button:\n",
|
258 |
-
" remove_and_reupload = gr.ClearButton(value=\"Remove and upload new ones\", components=files, size=\"sm\")\n",
|
259 |
-
" prompt = gr.Textbox(label=\"Prompt\",\n",
|
260 |
-
" info=\"Make sure to include 'sks person'\" ,\n",
|
261 |
-
" placeholder=\"sks person\", \n",
|
262 |
-
" value=\"sks person\")\n",
|
263 |
-
" negative_prompt = gr.Textbox(label=\"Negative Prompt\", placeholder=\"low quality, blurry, unfinished, cartoon\", value=\"low quality, blurry, unfinished, cartoon\")\n",
|
264 |
-
" seed = gr.Number(value=5, precision=0, label=\"Seed\", interactive=True)\n",
|
265 |
-
" cfg = gr.Slider(label=\"CFG\", value=3.0, step=0.1, minimum=0, maximum=10, interactive=True)\n",
|
266 |
-
" steps = gr.Slider(label=\"Inference Steps\", precision=0, value=50, step=1, minimum=0, maximum=100, interactive=True)\n",
|
267 |
-
"\n",
|
268 |
-
"\n",
|
269 |
-
" submit = gr.Button(\"Submit\")\n",
|
270 |
-
"\n",
|
271 |
-
" with gr.Column():\n",
|
272 |
-
" gallery = gr.Gallery(label=\"Generated Images\")\n",
|
273 |
-
"\n",
|
274 |
-
" sample.click(fn=sample_model)\n",
|
275 |
-
" \n",
|
276 |
-
" submit.click(fn=inference,\n",
|
277 |
-
" inputs=[prompt, negative_prompt, cfg, steps, seed],\n",
|
278 |
-
" outputs=gallery)\n",
|
279 |
-
" \n",
|
280 |
-
"\n",
|
281 |
-
"\n",
|
282 |
-
"\n",
|
283 |
-
" \n",
|
284 |
-
" \n",
|
285 |
-
"demo.launch(share=True)"
|
286 |
-
]
|
287 |
-
},
|
288 |
-
{
|
289 |
-
"cell_type": "code",
|
290 |
-
"execution_count": null,
|
291 |
-
"metadata": {},
|
292 |
-
"outputs": [],
|
293 |
-
"source": []
|
294 |
-
}
|
295 |
-
],
|
296 |
-
"metadata": {
|
297 |
-
"kernelspec": {
|
298 |
-
"display_name": "dblora2",
|
299 |
-
"language": "python",
|
300 |
-
"name": "python3"
|
301 |
-
},
|
302 |
-
"language_info": {
|
303 |
-
"codemirror_mode": {
|
304 |
-
"name": "ipython",
|
305 |
-
"version": 3
|
306 |
-
},
|
307 |
-
"file_extension": ".py",
|
308 |
-
"mimetype": "text/x-python",
|
309 |
-
"name": "python",
|
310 |
-
"nbconvert_exporter": "python",
|
311 |
-
"pygments_lexer": "ipython3",
|
312 |
-
"version": "3.8.18"
|
313 |
-
}
|
314 |
-
},
|
315 |
-
"nbformat": 4,
|
316 |
-
"nbformat_minor": 2
|
317 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|