WonwoongCho commited on
Commit
dff3f07
·
1 Parent(s): 20bada1

add torchvision to requirements

Browse files
Files changed (2) hide show
  1. app.py +41 -42
  2. requirements.txt +1 -2
app.py CHANGED
@@ -12,10 +12,49 @@ from src.utils_sample import set_seed, resize_and_add_margin
12
  import os
13
 
14
 
 
 
 
 
 
 
 
 
 
 
 
15
  @spaces.GPU
16
- def process_image_and_text(image, scale, seed, text, pipe):
17
  set_seed(seed)
 
 
 
 
 
 
 
18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  # image = Image.open(img_path).convert('RGB')
20
  image = resize_and_add_margin(image, target_size=512)
21
 
@@ -108,16 +147,6 @@ header = """
108
 
109
  def create_app():
110
 
111
- dtype = torch.bfloat16
112
- token = os.environ.get("HF_TOKEN")
113
-
114
- pipe = FluxPipeline.from_pretrained(
115
- "black-forest-labs/FLUX.1-dev",
116
- torch_dtype=dtype,
117
- use_auth_token=token
118
- )
119
- pipe = pipe.to("cuda")
120
-
121
  with gr.Blocks() as app:
122
  gr.Markdown(header, elem_id="header")
123
  with gr.Row(equal_height=False):
@@ -147,39 +176,9 @@ def create_app():
147
  label="Examples",
148
  )
149
 
150
- print("execution_device 1", pipe._execution_device)
151
- blended_attn_procs = {}
152
- for name, _ in pipe.transformer.attn_processors.items():
153
- if "single" in name:
154
- blended_attn_procs[name] = FluxBlendedAttnProcessor2_0(3072, ba_scale=scale, num_ref=1)
155
- else:
156
- blended_attn_procs[name] = pipe.transformer.attn_processors[name]
157
-
158
- pipe.transformer.set_attn_processor(blended_attn_procs)
159
- pipe = pipe.to(dtype)
160
- pipe = pipe.to("cuda")
161
- print("execution_device 2", pipe._execution_device)
162
-
163
- model_path = hf_hub_download(
164
- repo_id="WonwoongCho/IT-Blender",
165
- filename="FLUX/it-blender.bin",
166
- token=token
167
- )
168
- pretrained_blended_attn_weights = torch.load(model_path, map_location=pipe._execution_device)
169
-
170
- key_changed_blended_attn_weights = {}
171
- for key, value in pretrained_blended_attn_weights.items():
172
- block_idx = int(key.split(".")[0]) - 21
173
- k_or_v = key.split("_")[2]
174
- changed_key = f'single_transformer_blocks.{block_idx}.attn.processor.blended_attention_{k_or_v}_proj.weight'
175
- key_changed_blended_attn_weights[changed_key] = value.to(dtype)
176
-
177
- missing_keys, unexpected_keys = pipe.transformer.load_state_dict(key_changed_blended_attn_weights, strict=False)
178
-
179
-
180
  submit_btn.click(
181
  fn=process_image_and_text,
182
- inputs=[original_image, scale, seed, text, pipe],
183
  outputs=output_image,
184
  )
185
 
 
12
  import os
13
 
14
 
15
+ dtype = torch.bfloat16
16
+ token = os.environ.get("HF_TOKEN")
17
+
18
+ pipe = None
19
+ pipe = FluxPipeline.from_pretrained(
20
+ "black-forest-labs/FLUX.1-dev",
21
+ torch_dtype=dtype,
22
+ use_auth_token=token
23
+ )
24
+ pipe = pipe.to("cuda")
25
+
26
  @spaces.GPU
27
+ def process_image_and_text(image, scale, seed, text):
28
  set_seed(seed)
29
+ print("execution_device 1", pipe._execution_device)
30
+ blended_attn_procs = {}
31
+ for name, _ in pipe.transformer.attn_processors.items():
32
+ if "single" in name:
33
+ blended_attn_procs[name] = FluxBlendedAttnProcessor2_0(3072, ba_scale=scale, num_ref=1)
34
+ else:
35
+ blended_attn_procs[name] = pipe.transformer.attn_processors[name]
36
 
37
+ pipe.transformer.set_attn_processor(blended_attn_procs)
38
+ pipe = pipe.to(dtype)
39
+ pipe = pipe.to("cuda")
40
+ print("execution_device 2", pipe._execution_device)
41
+
42
+ model_path = hf_hub_download(
43
+ repo_id="WonwoongCho/IT-Blender",
44
+ filename="FLUX/it-blender.bin",
45
+ token=token
46
+ )
47
+ pretrained_blended_attn_weights = torch.load(model_path, map_location=pipe._execution_device)
48
+
49
+ key_changed_blended_attn_weights = {}
50
+ for key, value in pretrained_blended_attn_weights.items():
51
+ block_idx = int(key.split(".")[0]) - 21
52
+ k_or_v = key.split("_")[2]
53
+ changed_key = f'single_transformer_blocks.{block_idx}.attn.processor.blended_attention_{k_or_v}_proj.weight'
54
+ key_changed_blended_attn_weights[changed_key] = value.to(dtype)
55
+
56
+ missing_keys, unexpected_keys = pipe.transformer.load_state_dict(key_changed_blended_attn_weights, strict=False)
57
+
58
  # image = Image.open(img_path).convert('RGB')
59
  image = resize_and_add_margin(image, target_size=512)
60
 
 
147
 
148
  def create_app():
149
 
 
 
 
 
 
 
 
 
 
 
150
  with gr.Blocks() as app:
151
  gr.Markdown(header, elem_id="header")
152
  with gr.Row(equal_height=False):
 
176
  label="Examples",
177
  )
178
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
179
  submit_btn.click(
180
  fn=process_image_and_text,
181
+ inputs=[original_image, scale, seed, text],
182
  outputs=output_image,
183
  )
184
 
requirements.txt CHANGED
@@ -1,9 +1,8 @@
1
- torch
2
- torchvision
3
  transformers
4
  protobuf
5
  sentencepiece
6
  accelerate
7
  einops
8
  huggingface_hub
 
9
  git+https://github.com/WonwoongCho/diffusers@main#egg=diffusers
 
 
 
1
  transformers
2
  protobuf
3
  sentencepiece
4
  accelerate
5
  einops
6
  huggingface_hub
7
+ torchvision
8
  git+https://github.com/WonwoongCho/diffusers@main#egg=diffusers