medmac01 commited on
Commit
ff09d94
·
verified ·
1 Parent(s): b95e196

Upload 6 files

Browse files
Files changed (6) hide show
  1. README.md +5 -5
  2. app.py +61 -0
  3. model.py +187 -0
  4. requirements.txt +5 -0
  5. safety_checker.py +137 -0
  6. style.css +12 -0
README.md CHANGED
@@ -1,13 +1,13 @@
1
  ---
2
- title: Stable Diff Multilingual V0.1
3
- emoji: 😻
4
  colorFrom: yellow
5
- colorTo: indigo
6
  sdk: gradio
7
- sdk_version: 4.20.1
8
  app_file: app.py
9
  pinned: false
10
  license: mit
11
  ---
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: SDXL Lightning
3
+ emoji:
4
  colorFrom: yellow
5
+ colorTo: gray
6
  sdk: gradio
7
+ sdk_version: 4.19.1
8
  app_file: app.py
9
  pinned: false
10
  license: mit
11
  ---
12
 
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from huggingface_hub import hf_hub_download
3
+ from safetensors.torch import load_file
4
+ from PIL import Image
5
+
6
+ from model import *
7
+
8
+ # SAFETY_CHECKER = os.environ.get("SAFETY_CHECKER", "0") == "1"
9
+
10
+ # Constants
11
+ # base = "stabilityai/stable-diffusion-xl-base-1.0"
12
+ # repo = "ByteDance/SDXL-Lightning"
13
+ # checkpoints = {
14
+ # "1-Step" : ["sdxl_lightning_1step_unet_x0.safetensors", 1],
15
+ # "2-Step" : ["sdxl_lightning_2step_unet.safetensors", 2],
16
+ # "4-Step" : ["sdxl_lightning_4step_unet.safetensors", 4],
17
+ # "8-Step" : ["sdxl_lightning_8step_unet.safetensors", 8],
18
+ # }
19
+ # loaded = None
20
+
21
+
22
+ # Ensure model and scheduler are initialized in GPU-enabled function
23
+ # if torch.cuda.is_available():
24
+ # pipe = StableDiffusionXLPipeline.from_pretrained(base, torch_dtype=torch.float16, variant="fp16").to("cuda")
25
+
26
+
27
+ # Function
28
+ # @spaces.GPU(enable_queue=True)
29
+
30
+ def generate_image(prompt):
31
+
32
+ return prompt_to_img(prompt)[0]
33
+
34
+
35
+
36
+ # Gradio Interface
37
+ description = """
38
+ This demo utilizes the SDXL-Lightning model by ByteDance, which is a lightning-fast text-to-image generative model capable of producing high-quality images in 4 steps.
39
+ As a community effort, this demo was put together by AngryPenguin. Link to model: https://huggingface.co/ByteDance/SDXL-Lightning
40
+ """
41
+
42
+ with gr.Blocks(css="style.css") as demo:
43
+ gr.HTML("<h1><center>Text-to-Image with SDXL-Lightning ⚡</center></h1>")
44
+ gr.Markdown(description)
45
+ with gr.Group():
46
+ with gr.Row():
47
+ prompt = gr.Textbox(label='Enter your prompt (English)', scale=8)
48
+ ckpt = gr.Dropdown(label='Select inference steps',choices=['1-Step', '2-Step', '4-Step', '8-Step'], value='4-Step', interactive=True)
49
+ submit = gr.Button(scale=1, variant='primary')
50
+ img = gr.Image(label='SDXL-Lightning Generated Image')
51
+
52
+ prompt.submit(fn=generate_image,
53
+ inputs=[prompt],
54
+ outputs=img,
55
+ )
56
+ submit.click(fn=generate_image,
57
+ inputs=[prompt],
58
+ outputs=img,
59
+ )
60
+
61
+ demo.queue().launch()
model.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from PIL import Image, ImageDraw
3
+ import cv2
4
+ import numpy as np
5
+ from IPython.display import HTML
6
+ from base64 import b64encode
7
+
8
+ import torch
9
+ from torch import autocast
10
+ from torch.nn import functional as F
11
+ from diffusers import StableDiffusionPipeline, AutoencoderKL
12
+ from diffusers import UNet2DConditionModel, PNDMScheduler, LMSDiscreteScheduler
13
+ from diffusers.schedulers.scheduling_ddim import DDIMScheduler
14
+ #from transformers import CLIPTextModel, CLIPTokenizer
15
+ from tqdm.auto import tqdm
16
+ from huggingface_hub import notebook_login
17
+
18
+ import weights
19
+
20
+ device = 'cpu'
21
+
22
+ from Multilingual_CLIP.multilingual_clip import Config_MCLIP
23
+ import transformers
24
+ import torch
25
+
26
+
27
+ class MultilingualCLIP(transformers.PreTrainedModel):
28
+ config_class = Config_MCLIP.MCLIPConfig
29
+
30
+ def __init__(self, config, *args, **kwargs):
31
+ super().__init__(config, *args, **kwargs)
32
+ self.transformer = transformers.AutoModel.from_pretrained(config.modelBase)
33
+ self.LinearTransformation = torch.nn.Linear(in_features=config.transformerDimensions,
34
+ out_features=config.numDims)
35
+
36
+ def forward(self, txt, tokenizer, device):
37
+ txt_tok = tokenizer(txt, padding='max_length', max_length=77, truncation=True, return_tensors='pt').to(device)
38
+ embs = self.transformer(**txt_tok)
39
+ embs = embs[0]
40
+ att = txt_tok['attention_mask']
41
+ embs = (embs * att.unsqueeze(2)) / att.sum(dim=1)[:, None].unsqueeze(2)
42
+ return self.LinearTransformation(embs)
43
+
44
+ @classmethod
45
+ def _load_state_dict_into_model(cls, model, state_dict, pretrained_model_name_or_path, _fast_init=True):
46
+ model.load_state_dict(state_dict)
47
+ return model, [], [], []
48
+
49
+
50
+ import torch
51
+ import torch.nn as nn
52
+
53
+ # Define the adaptation layer, 'checkpoint_9.pth'
54
+ class AdaptationLayer(nn.Module):
55
+ def __init__(self, input_dim, output_dim):
56
+ super(AdaptationLayer, self).__init__()
57
+ self.fc1 = nn.Linear(input_dim, output_dim*2)
58
+ torch.nn.init.kaiming_uniform_(self.fc1.weight, nonlinearity='relu')
59
+ self.bn1 = nn.BatchNorm1d(77)
60
+
61
+ self.fc2 = nn.Linear(input_dim*2, output_dim*2)
62
+ torch.nn.init.kaiming_uniform_(self.fc2.weight, nonlinearity='relu')
63
+ self.bn2 = nn.BatchNorm1d(77)
64
+
65
+ self.fc3 = nn.Linear(input_dim*2, output_dim)
66
+ torch.nn.init.kaiming_uniform_(self.fc3.weight, nonlinearity='relu')
67
+ self.bn3 = nn.BatchNorm1d(77)
68
+
69
+ self.fc4 = nn.Linear(input_dim, output_dim)
70
+ torch.nn.init.kaiming_uniform_(self.fc4.weight, nonlinearity='relu')
71
+ self.bn4 = nn.BatchNorm1d(77)
72
+
73
+ self.fc5 = nn.Linear(input_dim, output_dim)
74
+
75
+ def forward(self, x):
76
+ x = nn.functional.normalize(x, p=2.0, dim=1, eps=1e-12, out=None)
77
+ x = torch.relu(self.bn1(self.fc1(x)))
78
+ x = torch.relu(self.bn2(self.fc2(x)))
79
+ x = torch.relu(self.bn3(self.fc3(x)))
80
+ x = torch.relu(self.bn4(self.fc4(x)))
81
+
82
+ return self.fc5(x)
83
+
84
+
85
+ adapt_model = AdaptationLayer(768,768)
86
+ adapt_model.to(device)
87
+ state_dict = torch.load('weights/checkpoint_9.pth')
88
+ adapt_model.load_state_dict(state_dict)
89
+
90
+ # 1. Load the autoencoder model which will be used to decode the latents into image space.
91
+ vae = AutoencoderKL.from_pretrained(
92
+ 'CompVis/stable-diffusion-v1-4', subfolder='vae', use_auth_token=True)
93
+ vae = vae.to(device)
94
+
95
+ # 2. Load the tokenizer and text encoder to tokenize and encode the text.
96
+ tokenizer = text_tokenizer
97
+ text_encoder = text_model
98
+
99
+ # 3. The UNet model for generating the latents.
100
+ unet = UNet2DConditionModel.from_pretrained(
101
+ 'CompVis/stable-diffusion-v1-4', subfolder='unet', use_auth_token=True)
102
+ unet = unet.to(device)
103
+
104
+ # 4. Create a scheduler for inference
105
+ scheduler = LMSDiscreteScheduler(
106
+ beta_start=0.00085, beta_end=0.012,
107
+ beta_schedule='scaled_linear', num_train_timesteps=1000)
108
+
109
+
110
+ def get_text_embeds(prompt):
111
+ with torch.no_grad():
112
+ text_embeddings = text_model(prompt, text_tokenizer, device)
113
+ text_embeddings = adapt_model(text_embeddings)
114
+
115
+ # Do the same for unconditional embeddings
116
+ with torch.no_grad():
117
+ uncond_embeddings = text_model([''] * len(prompt), text_tokenizer, device)
118
+ uncond_embeddings = adapt_model(uncond_embeddings)
119
+
120
+ # Cat for final embeddings
121
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
122
+ return text_embeddings
123
+
124
+
125
+ def produce_latents(text_embeddings, height=512, width=512,
126
+ num_inference_steps=50, guidance_scale=7.5, latents=None):
127
+ if latents is None:
128
+ latents = torch.randn((text_embeddings.shape[0] // 2, unet.in_channels, \
129
+ height // 8, width // 8))
130
+ latents = latents.to(device)
131
+
132
+ scheduler.set_timesteps(num_inference_steps)
133
+ latents = latents * scheduler.sigmas[0]
134
+
135
+ with autocast('cpu'):
136
+ for i, t in tqdm(enumerate(scheduler.timesteps)):
137
+ # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
138
+ latent_model_input = torch.cat([latents] * 2)
139
+ sigma = scheduler.sigmas[i]
140
+ latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
141
+
142
+ # predict the noise residual
143
+ with torch.no_grad():
144
+ noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings.to(device))['sample']
145
+
146
+ # perform guidance
147
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
148
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
149
+
150
+ # compute the previous noisy sample x_t -> x_t-1
151
+ latents = scheduler.step(noise_pred, i, latents)['prev_sample']
152
+
153
+ return latents
154
+
155
+
156
+ def decode_img_latents(latents):
157
+ latents = 1 / 0.18215 * latents
158
+
159
+ with torch.no_grad():
160
+ imgs = vae.decode(latents)
161
+
162
+ imgs = (imgs / 2 + 0.5).clamp(0, 1)
163
+ imgs = imgs.detach().cpu().permute(0, 2, 3, 1).numpy()
164
+ imgs = (imgs * 255).round().astype('uint8')
165
+ pil_images = [Image.fromarray(image) for image in imgs]
166
+ return pil_images
167
+
168
+ def prompt_to_img(prompts, height=512, width=512, num_inference_steps=50,
169
+ guidance_scale=7.5, latents=None):
170
+ if isinstance(prompts, str):
171
+ prompts = [prompts]
172
+
173
+ # Prompts -> text embeds
174
+ text_embeds = get_text_embeds(prompts)
175
+
176
+ # Text embeds -> img latents
177
+ latents = produce_latents(
178
+ text_embeds, height=height, width=width, latents=latents,
179
+ num_inference_steps=num_inference_steps, guidance_scale=guidance_scale)
180
+
181
+ # Img latents -> imgs
182
+ imgs = decode_img_latents(latents)
183
+
184
+ return imgs
185
+
186
+
187
+
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ transformers
2
+ diffusers
3
+ torch
4
+ accelerate
5
+ gradio
safety_checker.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import numpy as np
16
+ import torch
17
+ import torch.nn as nn
18
+ from transformers import CLIPConfig, CLIPVisionModel, PreTrainedModel
19
+
20
+
21
+ def cosine_distance(image_embeds, text_embeds):
22
+ normalized_image_embeds = nn.functional.normalize(image_embeds)
23
+ normalized_text_embeds = nn.functional.normalize(text_embeds)
24
+ return torch.mm(normalized_image_embeds, normalized_text_embeds.t())
25
+
26
+
27
+ class StableDiffusionSafetyChecker(PreTrainedModel):
28
+ config_class = CLIPConfig
29
+
30
+ _no_split_modules = ["CLIPEncoderLayer"]
31
+
32
+ def __init__(self, config: CLIPConfig):
33
+ super().__init__(config)
34
+
35
+ self.vision_model = CLIPVisionModel(config.vision_config)
36
+ self.visual_projection = nn.Linear(
37
+ config.vision_config.hidden_size, config.projection_dim, bias=False
38
+ )
39
+
40
+ self.concept_embeds = nn.Parameter(
41
+ torch.ones(17, config.projection_dim), requires_grad=False
42
+ )
43
+ self.special_care_embeds = nn.Parameter(
44
+ torch.ones(3, config.projection_dim), requires_grad=False
45
+ )
46
+
47
+ self.concept_embeds_weights = nn.Parameter(torch.ones(17), requires_grad=False)
48
+ self.special_care_embeds_weights = nn.Parameter(
49
+ torch.ones(3), requires_grad=False
50
+ )
51
+
52
+ @torch.no_grad()
53
+ def forward(self, clip_input, images):
54
+ pooled_output = self.vision_model(clip_input)[1] # pooled_output
55
+ image_embeds = self.visual_projection(pooled_output)
56
+
57
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
58
+ special_cos_dist = (
59
+ cosine_distance(image_embeds, self.special_care_embeds)
60
+ .cpu()
61
+ .float()
62
+ .numpy()
63
+ )
64
+ cos_dist = (
65
+ cosine_distance(image_embeds, self.concept_embeds).cpu().float().numpy()
66
+ )
67
+
68
+ result = []
69
+ batch_size = image_embeds.shape[0]
70
+ for i in range(batch_size):
71
+ result_img = {
72
+ "special_scores": {},
73
+ "special_care": [],
74
+ "concept_scores": {},
75
+ "bad_concepts": [],
76
+ }
77
+
78
+ # increase this value to create a stronger `nfsw` filter
79
+ # at the cost of increasing the possibility of filtering benign images
80
+ adjustment = 0.0
81
+
82
+ for concept_idx in range(len(special_cos_dist[0])):
83
+ concept_cos = special_cos_dist[i][concept_idx]
84
+ concept_threshold = self.special_care_embeds_weights[concept_idx].item()
85
+ result_img["special_scores"][concept_idx] = round(
86
+ concept_cos - concept_threshold + adjustment, 3
87
+ )
88
+ if result_img["special_scores"][concept_idx] > 0:
89
+ result_img["special_care"].append(
90
+ {concept_idx, result_img["special_scores"][concept_idx]}
91
+ )
92
+ adjustment = 0.01
93
+
94
+ for concept_idx in range(len(cos_dist[0])):
95
+ concept_cos = cos_dist[i][concept_idx]
96
+ concept_threshold = self.concept_embeds_weights[concept_idx].item()
97
+ result_img["concept_scores"][concept_idx] = round(
98
+ concept_cos - concept_threshold + adjustment, 3
99
+ )
100
+ if result_img["concept_scores"][concept_idx] > 0:
101
+ result_img["bad_concepts"].append(concept_idx)
102
+
103
+ result.append(result_img)
104
+
105
+ has_nsfw_concepts = [len(res["bad_concepts"]) > 0 for res in result]
106
+
107
+ return has_nsfw_concepts
108
+
109
+ @torch.no_grad()
110
+ def forward_onnx(self, clip_input: torch.FloatTensor, images: torch.FloatTensor):
111
+ pooled_output = self.vision_model(clip_input)[1] # pooled_output
112
+ image_embeds = self.visual_projection(pooled_output)
113
+
114
+ special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds)
115
+ cos_dist = cosine_distance(image_embeds, self.concept_embeds)
116
+
117
+ # increase this value to create a stronger `nsfw` filter
118
+ # at the cost of increasing the possibility of filtering benign images
119
+ adjustment = 0.0
120
+
121
+ special_scores = (
122
+ special_cos_dist - self.special_care_embeds_weights + adjustment
123
+ )
124
+ # special_scores = special_scores.round(decimals=3)
125
+ special_care = torch.any(special_scores > 0, dim=1)
126
+ special_adjustment = special_care * 0.01
127
+ special_adjustment = special_adjustment.unsqueeze(1).expand(
128
+ -1, cos_dist.shape[1]
129
+ )
130
+
131
+ concept_scores = (cos_dist - self.concept_embeds_weights) + special_adjustment
132
+ # concept_scores = concept_scores.round(decimals=3)
133
+ has_nsfw_concepts = torch.any(concept_scores > 0, dim=1)
134
+
135
+ images[has_nsfw_concepts] = 0.0 # black image
136
+
137
+ return images, has_nsfw_concepts
style.css ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ .gradio-container {
2
+ max-width: 690px! important;
3
+ }
4
+
5
+ #share-btn-container{padding-left: 0.5rem !important; padding-right: 0.5rem !important; background-color: #000000; justify-content: center; align-items: center; border-radius: 9999px !important; max-width: 13rem; margin-left: auto;margin-top: 0.35em;}
6
+ div#share-btn-container > div {flex-direction: row;background: black;align-items: center}
7
+ #share-btn-container:hover {background-color: #060606}
8
+ #share-btn {all: initial; color: #ffffff;font-weight: 600; cursor:pointer; font-family: 'IBM Plex Sans', sans-serif; margin-left: 0.5rem !important; padding-top: 0.5rem !important; padding-bottom: 0.5rem !important;right:0;font-size: 15px;}
9
+ #share-btn * {all: unset}
10
+ #share-btn-container div:nth-child(-n+2){width: auto !important;min-height: 0px !important;}
11
+ #share-btn-container .wrap {display: none !important}
12
+ #share-btn-container.hidden {display: none!important}