amildravid4292 committed
Commit 961c222 · verified · 1 Parent(s): df21382

Delete gradio.ipynb

Files changed (1)
gradio.ipynb +0 -317
gradio.ipynb DELETED
@@ -1,317 +0,0 @@
- {
-  "cells": [
-   {
-    "cell_type": "code",
-    "execution_count": 1,
-    "metadata": {},
-    "outputs": [
-     {
-      "name": "stderr",
-      "output_type": "stream",
-      "text": [
-       "/home/amil/anaconda3/envs/dblora2/lib/python3.8/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
-       "  from .autonotebook import tqdm as notebook_tqdm\n"
-      ]
-     },
-     {
-      "name": "stdout",
-      "output_type": "stream",
-      "text": [
-       "[2024-06-28 00:45:26,702] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n"
-      ]
-     }
-    ],
-    "source": [
-     "import gradio as gr\n",
-     "import sys\n",
-     "import os\n",
-     "import tqdm\n",
-     "sys.path.append(os.path.abspath(os.path.join(\"\", \"..\")))\n",
-     "import torch\n",
-     "import gc\n",
-     "import warnings\n",
-     "warnings.filterwarnings(\"ignore\")\n",
-     "from PIL import Image\n",
-     "from utils import load_models, save_model_w2w, save_model_for_diffusers\n",
-     "from sampling import sample_weights"
-    ]
-   },
-   {
-    "cell_type": "code",
-    "execution_count": 2,
-    "metadata": {},
-    "outputs": [],
-    "source": [
-     "global device\n",
-     "global generator\n",
-     "global unet\n",
-     "global vae\n",
-     "global text_encoder\n",
-     "global tokenizer\n",
-     "global noise_scheduler\n",
-     "device = \"cuda:0\"\n",
-     "generator = torch.Generator(device=device)"
-    ]
-   },
-   {
-    "cell_type": "code",
-    "execution_count": 3,
-    "metadata": {},
-    "outputs": [],
-    "source": [
-     "mean = torch.load(\"files/mean.pt\").bfloat16().to(device)\n",
-     "std = torch.load(\"files/std.pt\").bfloat16().to(device)\n",
-     "v = torch.load(\"files/V.pt\").bfloat16().to(device)\n",
-     "proj = torch.load(\"files/proj_1000pc.pt\").bfloat16().to(device)\n",
-     "df = torch.load(\"files/identity_df.pt\")\n",
-     "weight_dimensions = torch.load(\"files/weight_dimensions.pt\")"
-    ]
-   },
-   {
-    "cell_type": "code",
-    "execution_count": 4,
-    "metadata": {},
-    "outputs": [
-     {
-      "name": "stderr",
-      "output_type": "stream",
-      "text": [
-       "Loading pipeline components...: 100%|██████████| 6/6 [00:00<00:00, 10.79it/s]\n"
-      ]
-     },
-     {
-      "name": "stdout",
-      "output_type": "stream",
-      "text": [
-       "\n"
-      ]
-     }
-    ],
-    "source": [
-     "unet, vae, text_encoder, tokenizer, noise_scheduler = load_models(device)"
-    ]
-   },
-   {
-    "cell_type": "code",
-    "execution_count": 5,
-    "metadata": {},
-    "outputs": [],
-    "source": [
-     "global network"
-    ]
-   },
-   {
-    "cell_type": "code",
-    "execution_count": 6,
-    "metadata": {},
-    "outputs": [],
-    "source": [
-     "def sample_model():\n",
-     "    global unet\n",
-     "    del unet\n",
-     "    global network\n",
-     "    unet, _, _, _, _ = load_models(device)\n",
-     "    network = sample_weights(unet, proj, mean, std, v[:, :1000], device, factor = 1.00)\n"
-    ]
-   },
-   {
-    "cell_type": "code",
-    "execution_count": 7,
-    "metadata": {},
-    "outputs": [],
-    "source": [
-     "@torch.no_grad()\n",
-     "def inference(prompt, negative_prompt, guidance_scale, ddim_steps, seed):\n",
-     "    global device\n",
-     "    global generator\n",
-     "    global unet\n",
-     "    global vae\n",
-     "    global text_encoder\n",
-     "    global tokenizer\n",
-     "    global noise_scheduler\n",
-     "    generator = generator.manual_seed(seed)\n",
-     "    latents = torch.randn(\n",
-     "        (1, unet.in_channels, 512 // 8, 512 // 8),\n",
-     "        generator = generator,\n",
-     "        device = device\n",
-     "    ).bfloat16()\n",
-     "\n",
-     "    text_input = tokenizer(prompt, padding=\"max_length\", max_length=tokenizer.model_max_length, truncation=True, return_tensors=\"pt\")\n",
-     "\n",
-     "    text_embeddings = text_encoder(text_input.input_ids.to(device))[0]\n",
-     "\n",
-     "    max_length = text_input.input_ids.shape[-1]\n",
-     "    uncond_input = tokenizer(\n",
-     "        [negative_prompt], padding=\"max_length\", max_length=max_length, return_tensors=\"pt\"\n",
-     "    )\n",
-     "    uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]\n",
-     "    text_embeddings = torch.cat([uncond_embeddings, text_embeddings])\n",
-     "    noise_scheduler.set_timesteps(ddim_steps)\n",
-     "    latents = latents * noise_scheduler.init_noise_sigma\n",
-     "\n",
-     "    for i, t in enumerate(tqdm.tqdm(noise_scheduler.timesteps)):\n",
-     "        latent_model_input = torch.cat([latents] * 2)\n",
-     "        latent_model_input = noise_scheduler.scale_model_input(latent_model_input, timestep=t)\n",
-     "        with network:\n",
-     "            noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings, timestep_cond=None).sample\n",
-     "        # guidance\n",
-     "        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n",
-     "        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)\n",
-     "        latents = noise_scheduler.step(noise_pred, t, latents).prev_sample\n",
-     "\n",
-     "    latents = 1 / 0.18215 * latents\n",
-     "    image = vae.decode(latents).sample\n",
-     "    image = (image / 2 + 0.5).clamp(0, 1)\n",
-     "    image = image.detach().cpu().float().permute(0, 2, 3, 1).numpy()[0]\n",
-     "\n",
-     "    image = Image.fromarray((image * 255).round().astype(\"uint8\"))\n",
-     "\n",
-     "    return [image]"
-    ]
-   },
-   {
-    "cell_type": "code",
-    "execution_count": 8,
-    "metadata": {},
-    "outputs": [
-     {
-      "name": "stdout",
-      "output_type": "stream",
-      "text": [
-       "Running on local URL: http://127.0.0.1:7860\n",
-       "Running on public URL: https://bc89b27b9704787832.gradio.live\n",
-       "\n",
-       "This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)\n"
-      ]
-     },
-     {
-      "data": {
-       "text/html": [
-        "<div><iframe src=\"https://bc89b27b9704787832.gradio.live\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
-       ],
-       "text/plain": [
-        "<IPython.core.display.HTML object>"
-       ]
-      },
-      "metadata": {},
-      "output_type": "display_data"
-     },
-     {
-      "data": {
-       "text/plain": []
-      },
-      "execution_count": 8,
-      "metadata": {},
-      "output_type": "execute_result"
-     },
-     {
-      "name": "stderr",
-      "output_type": "stream",
-      "text": [
-       "Loading pipeline components...: 100%|██████████| 6/6 [00:00<00:00, 8.95it/s]\n",
-       "Traceback (most recent call last):\n",
-       "  File \"/home/amil/anaconda3/envs/dblora2/lib/python3.8/site-packages/gradio/routes.py\", line 437, in run_predict\n",
-       "    output = await app.get_blocks().process_api(\n",
-       "  File \"/home/amil/anaconda3/envs/dblora2/lib/python3.8/site-packages/gradio/blocks.py\", line 1352, in process_api\n",
-       "    result = await self.call_function(\n",
-       "  File \"/home/amil/anaconda3/envs/dblora2/lib/python3.8/site-packages/gradio/blocks.py\", line 1077, in call_function\n",
-       "    prediction = await anyio.to_thread.run_sync(\n",
-       "  File \"/home/amil/anaconda3/envs/dblora2/lib/python3.8/site-packages/anyio/to_thread.py\", line 56, in run_sync\n",
-       "    return await get_async_backend().run_sync_in_worker_thread(\n",
-       "  File \"/home/amil/anaconda3/envs/dblora2/lib/python3.8/site-packages/anyio/_backends/_asyncio.py\", line 2134, in run_sync_in_worker_thread\n",
-       "    return await future\n",
-       "  File \"/home/amil/anaconda3/envs/dblora2/lib/python3.8/site-packages/anyio/_backends/_asyncio.py\", line 851, in run\n",
-       "    result = context.run(func, *args)\n",
-       "  File \"/home/amil/anaconda3/envs/dblora2/lib/python3.8/site-packages/torch/utils/_contextlib.py\", line 115, in decorate_context\n",
-       "    return func(*args, **kwargs)\n",
-       "  File \"/tmp/ipykernel_2844069/1186401021.py\", line 12, in inference\n",
-       "    (1, unet.in_channels, 512 // 8, 512 // 8),\n",
-       "NameError: name 'unet' is not defined\n"
-      ]
-     },
-     {
-      "name": "stdout",
-      "output_type": "stream",
-      "text": [
-       "\n"
-      ]
-     }
-    ],
-    "source": [
-     "css = ''\n",
-     "with gr.Blocks(css=css) as demo:\n",
-     "    gr.Markdown(\"# <em>weights2weights</em> Demo\")\n",
-     "    gr.Markdown(\"Demo for the [h94/IP-Adapter-FaceID model](https://huggingface.co/h94/IP-Adapter-FaceID) - Generate AI images with your own face - Non-commercial license\")\n",
-     "    with gr.Row():\n",
-     "        with gr.Column():\n",
-     "            files = gr.Files(\n",
-     "                label=\"Upload a photo of your face to invert, or sample a new model\",\n",
-     "                file_types=[\"image\"]\n",
-     "            )\n",
-     "            uploaded_files = gr.Gallery(label=\"Your images\", visible=False, columns=5, rows=1, height=125)\n",
-     "\n",
-     "            sample = gr.Button(\"Sample New Model\")\n",
-     "\n",
-     "            with gr.Column(visible=False) as clear_button:\n",
-     "                remove_and_reupload = gr.ClearButton(value=\"Remove and upload new ones\", components=files, size=\"sm\")\n",
-     "            prompt = gr.Textbox(label=\"Prompt\",\n",
-     "                                info=\"Make sure to include 'sks person'\",\n",
-     "                                placeholder=\"sks person\",\n",
-     "                                value=\"sks person\")\n",
-     "            negative_prompt = gr.Textbox(label=\"Negative Prompt\", placeholder=\"low quality, blurry, unfinished, cartoon\", value=\"low quality, blurry, unfinished, cartoon\")\n",
-     "            seed = gr.Number(value=5, precision=0, label=\"Seed\", interactive=True)\n",
-     "            cfg = gr.Slider(label=\"CFG\", value=3.0, step=0.1, minimum=0, maximum=10, interactive=True)\n",
-     "            steps = gr.Slider(label=\"Inference Steps\", precision=0, value=50, step=1, minimum=0, maximum=100, interactive=True)\n",
-     "\n",
-     "            submit = gr.Button(\"Submit\")\n",
-     "\n",
-     "        with gr.Column():\n",
-     "            gallery = gr.Gallery(label=\"Generated Images\")\n",
-     "\n",
-     "    sample.click(fn=sample_model)\n",
-     "\n",
-     "    submit.click(fn=inference,\n",
-     "                 inputs=[prompt, negative_prompt, cfg, steps, seed],\n",
-     "                 outputs=gallery)\n",
-     "\n",
-     "demo.launch(share=True)"
-    ]
-   },
-   {
-    "cell_type": "code",
-    "execution_count": null,
-    "metadata": {},
-    "outputs": [],
-    "source": []
-   }
-  ],
-  "metadata": {
-   "kernelspec": {
-    "display_name": "dblora2",
-    "language": "python",
-    "name": "python3"
-   },
-   "language_info": {
-    "codemirror_mode": {
-     "name": "ipython",
-     "version": 3
-    },
-    "file_extension": ".py",
-    "mimetype": "text/x-python",
-    "name": "python",
-    "nbconvert_exporter": "python",
-    "pygments_lexer": "ipython3",
-    "version": "3.8.18"
-   }
-  },
-  "nbformat": 4,
-  "nbformat_minor": 2
- }
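
For context on the statistics the notebook loads (files/mean.pt, files/std.pt, files/V.pt, files/proj_1000pc.pt): these are consistent with a PCA parameterization of personalized model weights, from which `sample_weights` draws a new identity. The repo's `sampling` module is not part of this commit, so the following is only a hedged sketch of sampling in such a space; the helper name and the shapes (mean and columns of V in weight space, std per principal component) are assumptions, not the project's actual API, and the `proj` tensor the real function also receives is omitted here.

    import torch

    def sample_weight_vector_sketch(mean, std, V, num_pcs=1000, factor=1.0):
        # Hypothetical: draw coefficients for the leading principal components
        # from N(0, 1), scale by the per-component std (times an overall
        # `factor`, cf. factor = 1.00 in the notebook), and map back to
        # weight space through the PCA basis. Assumed shapes: mean (D,),
        # V (D, K), std (K,). The real sample_weights() additionally builds
        # the `network` context manager that patches the UNet.
        z = torch.randn(num_pcs, device=mean.device, dtype=mean.dtype)
        return mean + V[:, :num_pcs] @ (factor * z * std[:num_pcs])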
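The denoising loop in the deleted `inference` cell is standard classifier-free guidance: the prompt and negative prompt are encoded separately, the latents are duplicated so the UNet predicts both in one batch, and the two predictions are recombined as eps = eps_uncond + scale * (eps_text - eps_uncond). A minimal, self-contained illustration of that recombination step (shapes are illustrative; `cfg_combine` is a name introduced here, not from the notebook):

    import torch

    def cfg_combine(noise_pred: torch.Tensor, guidance_scale: float) -> torch.Tensor:
        # `noise_pred` stacks [unconditional, text-conditional] along the
        # batch dimension, matching torch.cat([latents] * 2) in the loop.
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        # guidance_scale == 1.0 reproduces the conditional prediction;
        # larger values push the sample harder toward the prompt.
        return noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

    # Batch of 2 (uncond + cond), 4 latent channels, 64x64 latents for 512px SD.
    pred = torch.randn(2, 4, 64, 64)
    out = cfg_combine(pred, guidance_scale=3.0)  # the UI's CFG slider value
    assert out.shape == (1, 4, 64, 64)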
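Likewise, the `1 / 0.18215` factor applied before `vae.decode` is the usual Stable Diffusion 1.x latent scaling constant (exposed as `vae.config.scaling_factor` in diffusers), and the tail of `inference` is the standard latents-to-PIL conversion. A compact equivalent of that tail, assuming a diffusers AutoencoderKL-style `vae`:

    import torch
    from PIL import Image

    @torch.no_grad()
    def latents_to_pil(latents: torch.Tensor, vae) -> Image.Image:
        # Undo the SD latent scaling, decode, map [-1, 1] -> [0, 1], then
        # convert the first image in the batch to a uint8 PIL image.
        image = vae.decode(latents / 0.18215).sample
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image[0].permute(1, 2, 0).float().cpu().numpy()
        return Image.fromarray((image * 255).round().astype("uint8"))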
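Finally, the traceback recorded in the last cell's stderr (`NameError: name 'unet' is not defined` inside `inference`) is consistent with the unguarded `del unet` in `sample_model`: between the `del` and the reassignment from `load_models`, the global name is unbound, so a Submit click that overlaps a still-running "Sample New Model" click fails. One way to close that window, offered as a suggestion rather than as the commit's code (it assumes the notebook's globals `device`, `proj`, `mean`, `std`, `v` and the repo's `load_models`/`sample_weights`), is to bind the new UNet to a local first:

    def sample_model():
        global unet, network
        # Load into a local first so the global `unet` is never unbound;
        # a concurrent inference() call then sees either the old model or
        # the new one, never a missing name.
        new_unet, _, _, _, _ = load_models(device)
        network = sample_weights(new_unet, proj, mean, std, v[:, :1000], device, factor=1.00)
        unet = new_unet
        # The original `del unet` freed GPU memory before reloading; if
        # memory is tight, release the old copy after the swap instead:
        torch.cuda.empty_cache()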