mkshing commited on
Commit
366401e
·
0 Parent(s):

initial commit

Browse files
Files changed (6) hide show
  1. .gitattributes +35 -0
  2. README.md +13 -0
  3. app.py +157 -0
  4. evosdxl_jp_v1.py +204 -0
  5. requirements.txt +6 -0
  6. safety_checker.py +137 -0
.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: EvoSDXL JP
3
+ emoji: 🐠
4
+ colorFrom: blue
5
+ colorTo: green
6
+ sdk: gradio
7
+ sdk_version: 4.21.0
8
+ app_file: app.py
9
+ pinned: false
10
+ license: other
11
+ ---
12
+
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ import os
4
+ import random
5
+ import uuid
6
+
7
+ import gradio as gr
8
+ import numpy as np
9
+ import spaces
10
+ import torch
11
+ from PIL import Image
12
+ from evosdxl_jp_v1 import load_evosdxl_jp
13
+
14
+ DESCRIPTION = """# 🐟 EvoSDXL-JP
15
+ 🤗 [モデル一覧](https://huggingface.co/SakanaAI) | 📚 [技術レポート](https://arxiv.org/abs/2403.13187) | 📝 [ブログ](https://sakana.ai/evosdxl-jp/) | 🐦 [Twitter](https://twitter.com/SakanaAILabs)
16
+
17
+ [EvoSDXL-JP](https://huggingface.co/SakanaAI/EvoSDXL-JP-v1)は[Sakana AI](https://sakana.ai/)が教育目的で開発した日本特化の高速な画像生成モデルです。
18
+ 入力した日本語プロンプトに沿った画像を生成することができます。より詳しくは、上記のブログをご参照ください。
19
+ """
20
+ if not torch.cuda.is_available():
21
+ DESCRIPTION += "\n<p>Running on CPU 🥶 This demo may not work on CPU.</p>"
22
+
23
+ MAX_SEED = np.iinfo(np.int32).max
24
+ CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES", "1") == "1"
25
+
26
+ device = "cuda" if torch.cuda.is_available() else "cpu"
27
+
28
+ NUM_IMAGES_PER_PROMPT = 1
29
+ ENABLE_CPU_OFFLOAD = False
30
+ USE_TORCH_COMPILE = False
31
+ SAFETY_CHECKER = True
32
+ DEVELOP_MODE = True
33
+ if SAFETY_CHECKER:
34
+ from safety_checker import StableDiffusionSafetyChecker
35
+ from transformers import CLIPFeatureExtractor
36
+
37
+ safety_checker = StableDiffusionSafetyChecker.from_pretrained(
38
+ "CompVis/stable-diffusion-safety-checker"
39
+ ).to(device)
40
+ feature_extractor = CLIPFeatureExtractor.from_pretrained(
41
+ "openai/clip-vit-base-patch32"
42
+ )
43
+
44
+ def check_nsfw_images(
45
+ images: list[Image.Image],
46
+ ) -> tuple[list[Image.Image], list[bool]]:
47
+ safety_checker_input = feature_extractor(images, return_tensors="pt").to(device)
48
+ has_nsfw_concepts = safety_checker(
49
+ images=[images],
50
+ clip_input=safety_checker_input.pixel_values.to(device)
51
+ )
52
+
53
+ return images, has_nsfw_concepts
54
+
55
+
56
+ pipe = load_evosdxl_jp("cpu").to("cuda")
57
+
58
+ def show_warning(warning_text: str) -> gr.Blocks:
59
+ with gr.Blocks() as demo:
60
+ gr.Markdown(warning_text)
61
+ return demo
62
+
63
+ def save_image(img):
64
+ unique_name = str(uuid.uuid4()) + ".png"
65
+ img.save(unique_name)
66
+ return unique_name
67
+
68
+
69
+ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
70
+ if randomize_seed:
71
+ seed = random.randint(0, MAX_SEED)
72
+ return seed
73
+
74
+
75
+ @spaces.GPU
76
+ def generate(
77
+ prompt: str,
78
+ seed: int = 0,
79
+ randomize_seed: bool = False,
80
+ progress=gr.Progress(track_tqdm=True),
81
+ ):
82
+ pipe.to(device)
83
+ seed = int(randomize_seed_fn(seed, randomize_seed))
84
+ generator = torch.Generator().manual_seed(seed)
85
+
86
+ images = pipe(
87
+ prompt=prompt,
88
+ width=1024,
89
+ height=1024,
90
+ guidance_scale=0,
91
+ num_inference_steps=4,
92
+ generator=generator,
93
+ num_images_per_prompt=NUM_IMAGES_PER_PROMPT,
94
+ output_type="pil",
95
+ ).images
96
+
97
+ if SAFETY_CHECKER:
98
+ images, has_nsfw_concepts = check_nsfw_images(images)
99
+ if any(has_nsfw_concepts):
100
+ gr.Warning("NSFW content detected.")
101
+ return Image.new("RGB", (512, 512)), seed
102
+ return images[0], seed
103
+
104
+
105
+ examples = [
106
+ "柴犬が草原に立つ、幻想的な空、アート、最高品質の写真、ピントが当たってる"
107
+ ]
108
+
109
+ css = '''
110
+ .gradio-container{max-width: 690px !important}
111
+ h1{text-align:center}
112
+ '''
113
+ with gr.Blocks(css=css) as demo:
114
+ gr.Markdown(DESCRIPTION)
115
+ with gr.Group():
116
+ with gr.Row():
117
+ prompt = gr.Textbox(placeholder="日本語でプロンプトを入力してください。", show_label=False, scale=8)
118
+ submit = gr.Button(scale=0)
119
+ result = gr.Image(label="EvoSDXL-JPからの生成結果", show_label=False)
120
+ with gr.Accordion("詳細設定", open=False):
121
+ seed = gr.Slider(
122
+ label="シード値",
123
+ minimum=0,
124
+ maximum=MAX_SEED,
125
+ step=1,
126
+ value=0,
127
+ )
128
+ randomize_seed = gr.Checkbox(label="ランダムにシード値を決定", value=True)
129
+
130
+ # gr.Examples(
131
+ # examples=examples,
132
+ # inputs=prompt,
133
+ # outputs=[result, seed],
134
+ # fn=generate,
135
+ # # cache_examples=CACHE_EXAMPLES,
136
+ # )
137
+
138
+ gr.on(
139
+ triggers=[
140
+ prompt.submit,
141
+ submit.click,
142
+ ],
143
+ fn=generate,
144
+ inputs=[
145
+ prompt,
146
+ seed,
147
+ randomize_seed,
148
+ ],
149
+ outputs=[result, seed],
150
+ api_name="run",
151
+ )
152
+ gr.Markdown("""⚠️ 本モデルは実験段階のプロトタイプであり、教育および研究開発の目的でのみ提供されています。商用利用や、障害が重大な影響を及ぼす可能性のある環境(ミッションクリティカルな環境)での使用には適していません。
153
+ 本モデルの使用は、利用者の自己責任で行われ、その性能や結果については何ら保証されません。
154
+ Sakana AIは、本モデルの使用によって生じた直接的または間接的な損失に対して、結果に関わらず、一切の責任を負いません。
155
+ 利用者は、本モデルの使用に伴うリスクを十分に理解し、自身の判断で使用することが必要です。""")
156
+
157
+ demo.queue().launch()
evosdxl_jp_v1.py ADDED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import List, Dict, Union
3
+ from tqdm import tqdm
4
+ import torch
5
+ import safetensors
6
+ from huggingface_hub import hf_hub_download
7
+ from transformers import AutoTokenizer, CLIPTextModelWithProjection
8
+ from diffusers import (
9
+ StableDiffusionXLPipeline,
10
+ UNet2DConditionModel,
11
+ EulerDiscreteScheduler,
12
+ )
13
+ from diffusers.loaders import LoraLoaderMixin
14
+
15
+ SDXL_REPO = "stabilityai/stable-diffusion-xl-base-1.0"
16
+ JSDXL_REPO = "stabilityai/japanese-stable-diffusion-xl"
17
+ L_REPO = "ByteDance/SDXL-Lightning"
18
+
19
+
20
+ def load_state_dict(checkpoint_file: Union[str, os.PathLike], device: str = "cpu"):
21
+ file_extension = os.path.basename(checkpoint_file).split(".")[-1]
22
+ if file_extension == "safetensors":
23
+ return safetensors.torch.load_file(checkpoint_file, device=device)
24
+ else:
25
+ return torch.load(checkpoint_file, map_location=device)
26
+
27
+
28
+ def load_from_pretrained(
29
+ repo_id,
30
+ filename="diffusion_pytorch_model.fp16.safetensors",
31
+ subfolder="unet",
32
+ device="cuda",
33
+ ) -> Dict[str, torch.Tensor]:
34
+ return load_state_dict(
35
+ hf_hub_download(
36
+ repo_id=repo_id,
37
+ filename=filename,
38
+ subfolder=subfolder,
39
+ ),
40
+ device=device,
41
+ )
42
+
43
+
44
+ def reshape_weight_task_tensors(task_tensors, weights):
45
+ """
46
+ Reshapes `weights` to match the shape of `task_tensors` by unsqeezing in the remaining dimenions.
47
+
48
+ Args:
49
+ task_tensors (`torch.Tensor`): The tensors that will be used to reshape `weights`.
50
+ weights (`torch.Tensor`): The tensor to be reshaped.
51
+
52
+ Returns:
53
+ `torch.Tensor`: The reshaped tensor.
54
+ """
55
+ new_shape = weights.shape + (1,) * (task_tensors.dim() - weights.dim())
56
+ weights = weights.view(new_shape)
57
+ return weights
58
+
59
+
60
+ def linear(task_tensors: List[torch.Tensor], weights: torch.Tensor) -> torch.Tensor:
61
+ """
62
+ Merge the task tensors using `linear`.
63
+
64
+ Args:
65
+ task_tensors(`List[torch.Tensor]`):The task tensors to merge.
66
+ weights (`torch.Tensor`):The weights of the task tensors.
67
+
68
+ Returns:
69
+ `torch.Tensor`: The merged tensor.
70
+ """
71
+ task_tensors = torch.stack(task_tensors, dim=0)
72
+ # weighted task tensors
73
+ weights = reshape_weight_task_tensors(task_tensors, weights)
74
+ weighted_task_tensors = task_tensors * weights
75
+ mixed_task_tensors = weighted_task_tensors.sum(dim=0)
76
+ return mixed_task_tensors
77
+
78
+
79
+ def merge_models(
80
+ task_tensors,
81
+ weights,
82
+ ):
83
+ keys = list(task_tensors[0].keys())
84
+ weights = torch.tensor(weights, device=task_tensors[0][keys[0]].device)
85
+ state_dict = {}
86
+ for key in tqdm(keys, desc="Merging"):
87
+ w_list = []
88
+ for i, sd in enumerate(task_tensors):
89
+ w = sd.pop(key)
90
+ w_list.append(w)
91
+ new_w = linear(task_tensors=w_list, weights=weights)
92
+ state_dict[key] = new_w
93
+ return state_dict
94
+
95
+
96
+ def split_conv_attn(weights):
97
+ attn_tensors = {}
98
+ conv_tensors = {}
99
+ for key in list(weights.keys()):
100
+ if any(k in key for k in ["to_k", "to_q", "to_v", "to_out.0"]):
101
+ attn_tensors[key] = weights.pop(key)
102
+ else:
103
+ conv_tensors[key] = weights.pop(key)
104
+ return {"conv": conv_tensors, "attn": attn_tensors}
105
+
106
+
107
+ def load_evosdxl_jp(device="cuda") -> StableDiffusionXLPipeline:
108
+ sdxl_weights = split_conv_attn(load_from_pretrained(SDXL_REPO, device=device))
109
+ dpo_weights = split_conv_attn(
110
+ load_from_pretrained(
111
+ "mhdang/dpo-sdxl-text2image-v1",
112
+ "diffusion_pytorch_model.safetensors",
113
+ device=device,
114
+ )
115
+ )
116
+ jn_weights = split_conv_attn(
117
+ load_from_pretrained("RunDiffusion/Juggernaut-XL-v9", device=device)
118
+ )
119
+ jsdxl_weights = split_conv_attn(load_from_pretrained(JSDXL_REPO, device=device))
120
+ tensors = [sdxl_weights, dpo_weights, jn_weights, jsdxl_weights]
121
+ new_conv = merge_models(
122
+ [sd["conv"] for sd in tensors],
123
+ [
124
+ 0.15928833971605916,
125
+ 0.1032449268871776,
126
+ 0.6503217149752791,
127
+ 0.08714501842148402,
128
+ ],
129
+ )
130
+ new_attn = merge_models(
131
+ [sd["attn"] for sd in tensors],
132
+ [
133
+ 0.1877279276437178,
134
+ 0.20014114603909822,
135
+ 0.3922685507065275,
136
+ 0.2198623756106564,
137
+ ],
138
+ )
139
+ del sdxl_weights, dpo_weights, jn_weights, jsdxl_weights
140
+ torch.cuda.empty_cache()
141
+ unet_config = UNet2DConditionModel.load_config(SDXL_REPO, subfolder="unet")
142
+ unet = UNet2DConditionModel.from_config(unet_config).to(device=device)
143
+ unet.load_state_dict({**new_conv, **new_attn})
144
+ state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(
145
+ L_REPO, weight_name="sdxl_lightning_4step_lora.safetensors"
146
+ )
147
+ LoraLoaderMixin.load_lora_into_unet(state_dict, network_alphas, unet)
148
+ unet.fuse_lora(lora_scale=3.224682864579401)
149
+ new_weights = split_conv_attn(unet.state_dict())
150
+ l_weights = split_conv_attn(
151
+ load_from_pretrained(
152
+ L_REPO,
153
+ "sdxl_lightning_4step_unet.safetensors",
154
+ subfolder=None,
155
+ device=device,
156
+ )
157
+ )
158
+ jnl_weights = split_conv_attn(
159
+ load_from_pretrained(
160
+ "RunDiffusion/Juggernaut-XL-Lightning",
161
+ "diffusion_pytorch_model.bin",
162
+ device=device,
163
+ )
164
+ )
165
+ tensors = [l_weights, jnl_weights, new_weights]
166
+ new_conv = merge_models(
167
+ [sd["conv"] for sd in tensors],
168
+ [0.47222002022088533, 0.48419531030361584, 0.04358466947549889],
169
+ )
170
+ new_attn = merge_models(
171
+ [sd["attn"] for sd in tensors],
172
+ [0.023119324530758375, 0.04924981616469831, 0.9276308593045434],
173
+ )
174
+ new_weights = {**new_conv, **new_attn}
175
+ unet = UNet2DConditionModel.from_config(unet_config).to(device=device)
176
+ unet.load_state_dict({**new_conv, **new_attn})
177
+
178
+ text_encoder = CLIPTextModelWithProjection.from_pretrained(
179
+ JSDXL_REPO, subfolder="text_encoder", torch_dtype=torch.float16, variant="fp16"
180
+ )
181
+ tokenizer = AutoTokenizer.from_pretrained(
182
+ JSDXL_REPO, subfolder="tokenizer", use_fast=False
183
+ )
184
+
185
+ pipe = StableDiffusionXLPipeline.from_pretrained(
186
+ SDXL_REPO,
187
+ unet=unet,
188
+ text_encoder=text_encoder,
189
+ tokenizer=tokenizer,
190
+ torch_dtype=torch.float16,
191
+ variant="fp16",
192
+ )
193
+ # Ensure sampler uses "trailing" timesteps.
194
+ pipe.scheduler = EulerDiscreteScheduler.from_config(
195
+ pipe.scheduler.config, timestep_spacing="trailing"
196
+ )
197
+ pipe = pipe.to(device, dtype=torch.float16)
198
+ return pipe
199
+
200
+
201
+ if __name__ == "__main__":
202
+ pipe: StableDiffusionXLPipeline = load_evosdxl_jp()
203
+ images = pipe("犬", num_inference_steps=4, guidance_scale=0).images
204
+ images[0].save("out.png")
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ torch
2
+ diffusers==0.26.0
3
+ transformers
4
+ safetensors
5
+ accelerate
6
+ sentencepiece
safety_checker.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import numpy as np
16
+ import torch
17
+ import torch.nn as nn
18
+ from transformers import CLIPConfig, CLIPVisionModel, PreTrainedModel
19
+
20
+
21
+ def cosine_distance(image_embeds, text_embeds):
22
+ normalized_image_embeds = nn.functional.normalize(image_embeds)
23
+ normalized_text_embeds = nn.functional.normalize(text_embeds)
24
+ return torch.mm(normalized_image_embeds, normalized_text_embeds.t())
25
+
26
+
27
+ class StableDiffusionSafetyChecker(PreTrainedModel):
28
+ config_class = CLIPConfig
29
+
30
+ _no_split_modules = ["CLIPEncoderLayer"]
31
+
32
+ def __init__(self, config: CLIPConfig):
33
+ super().__init__(config)
34
+
35
+ self.vision_model = CLIPVisionModel(config.vision_config)
36
+ self.visual_projection = nn.Linear(
37
+ config.vision_config.hidden_size, config.projection_dim, bias=False
38
+ )
39
+
40
+ self.concept_embeds = nn.Parameter(
41
+ torch.ones(17, config.projection_dim), requires_grad=False
42
+ )
43
+ self.special_care_embeds = nn.Parameter(
44
+ torch.ones(3, config.projection_dim), requires_grad=False
45
+ )
46
+
47
+ self.concept_embeds_weights = nn.Parameter(torch.ones(17), requires_grad=False)
48
+ self.special_care_embeds_weights = nn.Parameter(
49
+ torch.ones(3), requires_grad=False
50
+ )
51
+
52
+ @torch.no_grad()
53
+ def forward(self, clip_input, images):
54
+ pooled_output = self.vision_model(clip_input)[1] # pooled_output
55
+ image_embeds = self.visual_projection(pooled_output)
56
+
57
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
58
+ special_cos_dist = (
59
+ cosine_distance(image_embeds, self.special_care_embeds)
60
+ .cpu()
61
+ .float()
62
+ .numpy()
63
+ )
64
+ cos_dist = (
65
+ cosine_distance(image_embeds, self.concept_embeds).cpu().float().numpy()
66
+ )
67
+
68
+ result = []
69
+ batch_size = image_embeds.shape[0]
70
+ for i in range(batch_size):
71
+ result_img = {
72
+ "special_scores": {},
73
+ "special_care": [],
74
+ "concept_scores": {},
75
+ "bad_concepts": [],
76
+ }
77
+
78
+ # increase this value to create a stronger `nfsw` filter
79
+ # at the cost of increasing the possibility of filtering benign images
80
+ adjustment = 0.0
81
+
82
+ for concept_idx in range(len(special_cos_dist[0])):
83
+ concept_cos = special_cos_dist[i][concept_idx]
84
+ concept_threshold = self.special_care_embeds_weights[concept_idx].item()
85
+ result_img["special_scores"][concept_idx] = round(
86
+ concept_cos - concept_threshold + adjustment, 3
87
+ )
88
+ if result_img["special_scores"][concept_idx] > 0:
89
+ result_img["special_care"].append(
90
+ {concept_idx, result_img["special_scores"][concept_idx]}
91
+ )
92
+ adjustment = 0.01
93
+
94
+ for concept_idx in range(len(cos_dist[0])):
95
+ concept_cos = cos_dist[i][concept_idx]
96
+ concept_threshold = self.concept_embeds_weights[concept_idx].item()
97
+ result_img["concept_scores"][concept_idx] = round(
98
+ concept_cos - concept_threshold + adjustment, 3
99
+ )
100
+ if result_img["concept_scores"][concept_idx] > 0:
101
+ result_img["bad_concepts"].append(concept_idx)
102
+
103
+ result.append(result_img)
104
+
105
+ has_nsfw_concepts = [len(res["bad_concepts"]) > 0 for res in result]
106
+
107
+ return has_nsfw_concepts
108
+
109
+ @torch.no_grad()
110
+ def forward_onnx(self, clip_input: torch.FloatTensor, images: torch.FloatTensor):
111
+ pooled_output = self.vision_model(clip_input)[1] # pooled_output
112
+ image_embeds = self.visual_projection(pooled_output)
113
+
114
+ special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds)
115
+ cos_dist = cosine_distance(image_embeds, self.concept_embeds)
116
+
117
+ # increase this value to create a stronger `nsfw` filter
118
+ # at the cost of increasing the possibility of filtering benign images
119
+ adjustment = 0.0
120
+
121
+ special_scores = (
122
+ special_cos_dist - self.special_care_embeds_weights + adjustment
123
+ )
124
+ # special_scores = special_scores.round(decimals=3)
125
+ special_care = torch.any(special_scores > 0, dim=1)
126
+ special_adjustment = special_care * 0.01
127
+ special_adjustment = special_adjustment.unsqueeze(1).expand(
128
+ -1, cos_dist.shape[1]
129
+ )
130
+
131
+ concept_scores = (cos_dist - self.concept_embeds_weights) + special_adjustment
132
+ # concept_scores = concept_scores.round(decimals=3)
133
+ has_nsfw_concepts = torch.any(concept_scores > 0, dim=1)
134
+
135
+ images[has_nsfw_concepts] = 0.0 # black image
136
+
137
+ return images, has_nsfw_concepts