multimodalart HF staff commited on
Commit
a1b8e4d
1 Parent(s): 8772583

Create new file

Browse files
Files changed (1) hide show
  1. app.py +245 -0
app.py ADDED
@@ -0,0 +1,245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #@title Prepare the Concepts Library to be used
2
+
3
+ import requests
4
+ import os
5
+ import gradio as gr
6
+ import wget
7
+ import torch
8
+ from torch import autocast
9
+ from diffusers import StableDiffusionPipeline
10
+ from huggingface_hub import HfApi
11
+ from transformers import CLIPTextModel, CLIPTokenizer
12
+ from tqdm.notebook import tqdm
13
+
14
+ api = HfApi()
15
+ models_list = api.list_models(author="sd-concepts-library")
16
+ models = []
17
+
18
+ pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=True, revision="fp16", torch_dtype=torch.float16).to("cuda")
19
+
20
+ def load_learned_embed_in_clip(learned_embeds_path, text_encoder, tokenizer, token=None):
21
+ loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu")
22
+
23
+ # separate token and the embeds
24
+ trained_token = list(loaded_learned_embeds.keys())[0]
25
+ embeds = loaded_learned_embeds[trained_token]
26
+
27
+ # cast to dtype of text_encoder
28
+ dtype = text_encoder.get_input_embeddings().weight.dtype
29
+ embeds.to(dtype)
30
+
31
+ # add the token in tokenizer
32
+ token = token if token is not None else trained_token
33
+ num_added_tokens = tokenizer.add_tokens(token)
34
+ i = 1
35
+ while(num_added_tokens == 0):
36
+ print(f"The tokenizer already contains the token {token}.")
37
+ token = f"{token[:-1]}-{i}>"
38
+ print(f"Attempting to add the token {token}.")
39
+ num_added_tokens = tokenizer.add_tokens(token)
40
+ i+=1
41
+
42
+ # resize the token embeddings
43
+ text_encoder.resize_token_embeddings(len(tokenizer))
44
+
45
+ # get the id for the token and assign the embeds
46
+ token_id = tokenizer.convert_tokens_to_ids(token)
47
+ text_encoder.get_input_embeddings().weight.data[token_id] = embeds
48
+ return token
49
+
50
+ print("Setting up the public library")
51
+ for model in tqdm(models_list):
52
+ model_content = {}
53
+ model_id = model.modelId
54
+ model_content["id"] = model_id
55
+ embeds_url = f"https://huggingface.co/{model_id}/resolve/main/learned_embeds.bin"
56
+ os.makedirs(model_id,exist_ok = True)
57
+ if not os.path.exists(f"{model_id}/learned_embeds.bin"):
58
+ try:
59
+ wget.download(embeds_url, out=model_id)
60
+ except:
61
+ continue
62
+ token_identifier = f"https://huggingface.co/{model_id}/raw/main/token_identifier.txt"
63
+ response = requests.get(token_identifier)
64
+ token_name = response.text
65
+
66
+ concept_type = f"https://huggingface.co/{model_id}/raw/main/type_of_concept.txt"
67
+ response = requests.get(concept_type)
68
+ concept_name = response.text
69
+ model_content["concept_type"] = concept_name
70
+ images = []
71
+ for i in range(4):
72
+ url = f"https://huggingface.co/{model_id}/resolve/main/concept_images/{i}.jpeg"
73
+ image_download = requests.get(url)
74
+ url_code = image_download.status_code
75
+ if(url_code == 200):
76
+ file = open(f"{model_id}/{i}.jpeg", "wb") ## Creates the file for image
77
+ file.write(image_download.content) ## Saves file content
78
+ file.close()
79
+ images.append(f"{model_id}/{i}.jpeg")
80
+ model_content["images"] = images
81
+
82
+ learned_token = load_learned_embed_in_clip(f"{model_id}/learned_embeds.bin", pipe.text_encoder, pipe.tokenizer, token_name)
83
+ model_content["token"] = learned_token
84
+ models.append(model_content)
85
+
86
+ #@title Run the app to navigate around [the Library](https://huggingface.co/sd-concepts-library)
87
+ #@markdown Click the `Running on public URL:` result to run the Gradio app
88
+
89
+ SELECT_LABEL = "Select concept"
90
+
91
+ def title_block(title, id):
92
+ return gr.Markdown(f"### [`{title}`](https://huggingface.co/{id})")
93
+
94
+ def image_block(image_list, concept_type):
95
+ return gr.Gallery(
96
+ label=concept_type, value=image_list, elem_id="gallery"
97
+ ).style(grid=[2], height="auto")
98
+
99
+ def checkbox_block():
100
+ checkbox = gr.Checkbox(label=SELECT_LABEL).style(container=False)
101
+ return checkbox
102
+
103
+ def infer(text):
104
+ with autocast("cuda"):
105
+ images_list = pipe(
106
+ [text]*2,
107
+ num_inference_steps=50,
108
+ guidance_scale=7.5
109
+ )
110
+ output_images = []
111
+ for i, image in enumerate(images_list["sample"]):
112
+ output_images.append(image)
113
+ return output_images
114
+
115
+ css = '''
116
+ .gradio-container {font-family: 'IBM Plex Sans', sans-serif}
117
+ #top_title{margin-bottom: .5em}
118
+ #top_title h2{margin-bottom: 0; text-align: center}
119
+ #main_row{flex-wrap: wrap; gap: 1em; max-height: calc(100vh - 16em); overflow-y: scroll; flex-direction: row}
120
+ @media (min-width: 768px){#main_row > div{flex: 1 1 32%; margin-left: 0 !important}}
121
+ .gr-prose code::before, .gr-prose code::after {content: "" !important}
122
+ ::-webkit-scrollbar {width: 10px}
123
+ ::-webkit-scrollbar-track {background: #f1f1f1}
124
+ ::-webkit-scrollbar-thumb {background: #888}
125
+ ::-webkit-scrollbar-thumb:hover {background: #555}
126
+ .gr-button {white-space: nowrap}
127
+ .gr-button:focus {
128
+ border-color: rgb(147 197 253 / var(--tw-border-opacity));
129
+ outline: none;
130
+ box-shadow: var(--tw-ring-offset-shadow), var(--tw-ring-shadow), var(--tw-shadow, 0 0 #0000);
131
+ --tw-border-opacity: 1;
132
+ --tw-ring-offset-shadow: var(--tw-ring-inset) 0 0 0 var(--tw-ring-offset-width) var(--tw-ring-offset-color);
133
+ --tw-ring-shadow: var(--tw-ring-inset) 0 0 0 calc(3px var(--tw-ring-offset-width)) var(--tw-ring-color);
134
+ --tw-ring-color: rgb(191 219 254 / var(--tw-ring-opacity));
135
+ --tw-ring-opacity: .5;
136
+ }
137
+ #prompt_input{flex: 1 3 auto}
138
+ #prompt_area{margin-bottom: .75em}
139
+ #prompt_area > div:first-child{flex: 1 3 auto}
140
+ '''
141
+ examples = ["a <cat-toy> in <madhubani-art> style", "a mecha robot in <line-art> style", "a piano being played by <bonzi>"]
142
+ with gr.Blocks(css=css) as demo:
143
+ state = gr.Variable({
144
+ 'selected': -1
145
+ })
146
+ state = {}
147
+ def update_state(i):
148
+ global checkbox_states
149
+ if(checkbox_states[i]):
150
+ checkbox_states[i] = False
151
+ state[i] = False
152
+ else:
153
+ state[i] = True
154
+ checkbox_states[i] = True
155
+ gr.HTML('''
156
+ <div style="text-align: center; max-width: 720px; margin: 0 auto;">
157
+ <div
158
+ style="
159
+ display: inline-flex;
160
+ align-items: center;
161
+ gap: 0.8rem;
162
+ font-size: 1.75rem;
163
+ "
164
+ >
165
+ <svg
166
+ width="0.65em"
167
+ height="0.65em"
168
+ viewBox="0 0 115 115"
169
+ fill="none"
170
+ xmlns="http://www.w3.org/2000/svg"
171
+ >
172
+ <rect width="23" height="23" fill="white"></rect>
173
+ <rect y="69" width="23" height="23" fill="white"></rect>
174
+ <rect x="23" width="23" height="23" fill="#AEAEAE"></rect>
175
+ <rect x="23" y="69" width="23" height="23" fill="#AEAEAE"></rect>
176
+ <rect x="46" width="23" height="23" fill="white"></rect>
177
+ <rect x="46" y="69" width="23" height="23" fill="white"></rect>
178
+ <rect x="69" width="23" height="23" fill="black"></rect>
179
+ <rect x="69" y="69" width="23" height="23" fill="black"></rect>
180
+ <rect x="92" width="23" height="23" fill="#D9D9D9"></rect>
181
+ <rect x="92" y="69" width="23" height="23" fill="#AEAEAE"></rect>
182
+ <rect x="115" y="46" width="23" height="23" fill="white"></rect>
183
+ <rect x="115" y="115" width="23" height="23" fill="white"></rect>
184
+ <rect x="115" y="69" width="23" height="23" fill="#D9D9D9"></rect>
185
+ <rect x="92" y="46" width="23" height="23" fill="#AEAEAE"></rect>
186
+ <rect x="92" y="115" width="23" height="23" fill="#AEAEAE"></rect>
187
+ <rect x="92" y="69" width="23" height="23" fill="white"></rect>
188
+ <rect x="69" y="46" width="23" height="23" fill="white"></rect>
189
+ <rect x="69" y="115" width="23" height="23" fill="white"></rect>
190
+ <rect x="69" y="69" width="23" height="23" fill="#D9D9D9"></rect>
191
+ <rect x="46" y="46" width="23" height="23" fill="black"></rect>
192
+ <rect x="46" y="115" width="23" height="23" fill="black"></rect>
193
+ <rect x="46" y="69" width="23" height="23" fill="black"></rect>
194
+ <rect x="23" y="46" width="23" height="23" fill="#D9D9D9"></rect>
195
+ <rect x="23" y="115" width="23" height="23" fill="#AEAEAE"></rect>
196
+ <rect x="23" y="69" width="23" height="23" fill="black"></rect>
197
+ </svg>
198
+ <h1 style="font-weight: 900; margin-bottom: 7px;">
199
+ Stable Diffusion Conceptualizer
200
+ </h1>
201
+ </div>
202
+ <p style="margin-bottom: 10px; font-size: 94%">
203
+ Navigate through community created concepts and styles via Stable Diffusion Textual Inversion and pick yours for inference.
204
+ To train your own concepts and contribute to the library <a style="text-decoration: underline" href="https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_textual_inversion_training.ipynb">check out this notebook</a>.
205
+ </p>
206
+ </div>
207
+ ''')
208
+ with gr.Row():
209
+ with gr.Column():
210
+ gr.Markdown('''
211
+ ### Textual-Inversion trained [concepts library](https://huggingface.co/sd-concepts-library) navigator
212
+ ''')
213
+ with gr.Row(elem_id="main_row"):
214
+ image_blocks = []
215
+ for i, model in enumerate(models):
216
+ with gr.Box().style(border=None):
217
+ title_block(model["token"], model["id"])
218
+ image_blocks.append(image_block(model["images"], model["concept_type"]))
219
+ with gr.Box():
220
+ with gr.Row(elem_id="prompt_area").style(mobile_collapse=False, equal_height=True):
221
+ text = gr.Textbox(
222
+ label="Enter your prompt", placeholder="Enter your prompt", show_label=False, max_lines=1, elem_id="prompt_input"
223
+ ).style(
224
+ border=(True, False, True, True),
225
+ rounded=(True, False, False, True),
226
+ container=False
227
+ )
228
+ btn = gr.Button("Run",elem_id="run_btn").style(
229
+ margin=False,
230
+ rounded=(False, True, True, False)
231
+ )
232
+ with gr.Row().style():
233
+ infer_outputs = gr.Gallery(show_label=False).style(grid=[2], height="512px")
234
+ with gr.Row():
235
+ gr.HTML("<p style=\"font-size: 85%;margin-top: .75em\">Prompting may not work as you are used to; <code>objects</code> may need the concept added at the end.</p>")
236
+ with gr.Row():
237
+ gr.Examples(examples=examples, fn=infer, inputs=[text], outputs=infer_outputs, cache_examples=False)
238
+ checkbox_states = {}
239
+ inputs = [text]
240
+ btn.click(
241
+ infer,
242
+ inputs=inputs,
243
+ outputs=infer_outputs
244
+ )
245
+ demo.launch(inline=False, debug=True)