jiuface commited on
Commit
625a0c1
·
verified ·
1 Parent(s): 6840742

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -50
app.py CHANGED
@@ -41,10 +41,6 @@ good_vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae", torch_dtyp
41
  txt2img_pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype).to(device)
42
  txt2img_pipe.__class__.load_lora_into_transformer = classmethod(load_lora_into_transformer)
43
 
44
- # img2img model
45
- img2img_pipe = AutoPipelineForImage2Image.from_pretrained(base_model, vae=good_vae, transformer=txt2img_pipe.transformer, text_encoder=txt2img_pipe.text_encoder, tokenizer=txt2img_pipe.tokenizer, text_encoder_2=txt2img_pipe.text_encoder_2, tokenizer_2=txt2img_pipe.tokenizer_2, torch_dtype=dtype)
46
- img2img_pipe.__class__.load_lora_into_transformer = classmethod(load_lora_into_transformer)
47
-
48
 
49
  MAX_SEED = 2**32 - 1
50
 
@@ -118,15 +114,7 @@ def run_lora(prompt, image_url, lora_strings_json, image_strength, cfg_scale, s
118
  img2img_model = False
119
  orginal_image = None
120
  print(device)
121
- if image_url and image_url != "":
122
- print("img2img")
123
- orginal_image = load_image(image_url).to(device)
124
- img2img_model = True
125
- img2img_pipe.to(device)
126
- else:
127
- print("txt2img")
128
- txt2img_pipe.to(device)
129
-
130
  # Set random seed for reproducibility
131
  if randomize_seed:
132
  with calculateDuration("Set random seed"):
@@ -135,9 +123,8 @@ def run_lora(prompt, image_url, lora_strings_json, image_strength, cfg_scale, s
135
  # Load LoRA weights
136
  gr.Info("Start to load LoRA ...")
137
  with calculateDuration("Unloading LoRA"):
138
- # img2img_pipe.unload_lora_weights()
139
  txt2img_pipe.unload_lora_weights()
140
-
141
  lora_configs = None
142
  adapter_names = []
143
  lora_names = []
@@ -165,19 +152,13 @@ def run_lora(prompt, image_url, lora_strings_json, image_strength, cfg_scale, s
165
  adapter_weights.append(adapter_weight)
166
  if lora_repo and weights and adapter_name:
167
  try:
168
- if img2img_model:
169
- img2img_pipe.load_lora_weights(lora_repo, weight_name=weights, low_cpu_mem_usage=True, adapter_name=lora_name)
170
- else:
171
- txt2img_pipe.load_lora_weights(lora_repo, weight_name=weights, low_cpu_mem_usage=True, adapter_name=lora_name)
172
  except:
173
  print("load lora error")
174
 
175
  # set lora weights
176
  if len(lora_names) > 0:
177
- if img2img_model:
178
- img2img_pipe.set_adapters(lora_names, adapter_weights=adapter_weights)
179
- else:
180
- txt2img_pipe.set_adapters(lora_names, adapter_weights=adapter_weights)
181
 
182
  # Generate image
183
  error_message = ""
@@ -185,36 +166,20 @@ def run_lora(prompt, image_url, lora_strings_json, image_strength, cfg_scale, s
185
  gr.Info("Start to generate images ...")
186
  with calculateDuration(f"Make a new generator: {seed}"):
187
  generator = torch.Generator(device=device).manual_seed(seed)
188
-
189
  with calculateDuration("Generating image"):
190
  # Generate image
191
  joint_attention_kwargs = {"scale": 1}
192
-
193
- if orginal_image:
194
- img2img_pipe.to(device)
195
- final_image = img2img_pipe(
196
- prompt=prompt,
197
- image=orginal_image,
198
- strength=image_strength,
199
- num_inference_steps=steps,
200
- guidance_scale=cfg_scale,
201
- width=width,
202
- height=height,
203
- generator=generator,
204
- joint_attention_kwargs=joint_attention_kwargs
205
- ).images[0]
206
- else:
207
- txt2img_pipe.to(device)
208
- final_image = txt2img_pipe(
209
- prompt=prompt,
210
- num_inference_steps=steps,
211
- guidance_scale=cfg_scale,
212
- width=width,
213
- height=height,
214
- max_sequence_length=512,
215
- generator=generator,
216
- joint_attention_kwargs=joint_attention_kwargs
217
- ).images[0]
218
 
219
  except Exception as e:
220
  error_message = str(e)
 
41
  txt2img_pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype).to(device)
42
  txt2img_pipe.__class__.load_lora_into_transformer = classmethod(load_lora_into_transformer)
43
 
 
 
 
 
44
 
45
  MAX_SEED = 2**32 - 1
46
 
 
114
  img2img_model = False
115
  orginal_image = None
116
  print(device)
117
+
 
 
 
 
 
 
 
 
118
  # Set random seed for reproducibility
119
  if randomize_seed:
120
  with calculateDuration("Set random seed"):
 
123
  # Load LoRA weights
124
  gr.Info("Start to load LoRA ...")
125
  with calculateDuration("Unloading LoRA"):
 
126
  txt2img_pipe.unload_lora_weights()
127
+ print(device)
128
  lora_configs = None
129
  adapter_names = []
130
  lora_names = []
 
152
  adapter_weights.append(adapter_weight)
153
  if lora_repo and weights and adapter_name:
154
  try:
155
+ txt2img_pipe.load_lora_weights(lora_repo, weight_name=weights, low_cpu_mem_usage=True, adapter_name=lora_name)
 
 
 
156
  except:
157
  print("load lora error")
158
 
159
  # set lora weights
160
  if len(lora_names) > 0:
161
+ txt2img_pipe.set_adapters(lora_names, adapter_weights=adapter_weights)
 
 
 
162
 
163
  # Generate image
164
  error_message = ""
 
166
  gr.Info("Start to generate images ...")
167
  with calculateDuration(f"Make a new generator: {seed}"):
168
  generator = torch.Generator(device=device).manual_seed(seed)
169
+ print(device)
170
  with calculateDuration("Generating image"):
171
  # Generate image
172
  joint_attention_kwargs = {"scale": 1}
173
+ final_image = txt2img_pipe(
174
+ prompt=prompt,
175
+ num_inference_steps=steps,
176
+ guidance_scale=cfg_scale,
177
+ width=width,
178
+ height=height,
179
+ max_sequence_length=512,
180
+ generator=generator,
181
+ joint_attention_kwargs=joint_attention_kwargs
182
+ ).images[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
183
 
184
  except Exception as e:
185
  error_message = str(e)