cocktailpeanut, 6Morpheus6 committed
Commit 0542efe · verified · 1 Parent(s): 3902e7a
Files changed (1): app.py +53 -30
app.py CHANGED
@@ -14,26 +14,44 @@ from scheduling_dpmsolver_multistep_inject import DPMSolverMultistepSchedulerInject
 from transformers import AutoProcessor, BlipForConditionalGeneration
 from share_btn import community_icon_html, loading_icon_html, share_js
 
-# load pipelines
-# sd_model_id = "runwayml/stable-diffusion-v1-5"
-sd_model_id = "stabilityai/stable-diffusion-2-1-base"
-#device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-torch_dtype = torch.float16
-if torch.cuda.is_available():
-    device = "cuda"
-elif torch.backends.mps.is_available():
-    device = "mps"
-    torch_dtype = torch.float32
-else:
-    device = "cpu"
-vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch_dtype)
-pipe = SemanticStableDiffusionImg2ImgPipeline_DPMSolver.from_pretrained(sd_model_id,vae=vae,torch_dtype=torch_dtype, safety_checker=None, requires_safety_checker=False).to(device)
-pipe.scheduler = DPMSolverMultistepSchedulerInject.from_pretrained(sd_model_id, subfolder="scheduler"
-                                                                   , algorithm_type="sde-dpmsolver++", solver_order=2)
-
-blip_processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
-blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base",torch_dtype=torch_dtype).to(device)
+def load_models():
+    sd_model_id = "stabilityai/stable-diffusion-2-1-base"
+
+    torch_dtype = torch.float16
+    if torch.cuda.is_available():
+        device = "cuda"
+    elif torch.backends.mps.is_available():
+        device = "mps"
+        torch_dtype = torch.float32
+    else:
+        device = "cpu"
+
+    vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch_dtype)
+
+    pipe = SemanticStableDiffusionImg2ImgPipeline_DPMSolver.from_pretrained(
+        sd_model_id,
+        vae=vae,
+        torch_dtype=torch_dtype,
+        safety_checker=None,
+        requires_safety_checker=False
+    ).to(device)
+
+    pipe.scheduler = DPMSolverMultistepSchedulerInject.from_pretrained(
+        sd_model_id,
+        subfolder="scheduler",
+        algorithm_type="sde-dpmsolver++",
+        solver_order=2
+    )
+
+    blip_processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
+    blip_model = BlipForConditionalGeneration.from_pretrained(
+        "Salesforce/blip-image-captioning-base",
+        torch_dtype=torch_dtype
+    ).to(device)
+
+    return vae, pipe, blip_processor, blip_model, device, torch_dtype
+
+vae, pipe, blip_processor, blip_model, device, torch_dtype = load_models()
 
 ## IMAGE CAPTIONING ##
 def caption_image(input_image):
@@ -860,14 +878,19 @@ with gr.Blocks(css="style.css") as demo:
    clear_button.click(lambda: clear_components_output_vals, outputs = clear_components)
 
    reconstruct_button.click(lambda: ddpm_edited_image.update(visible=True), outputs=[ddpm_edited_image]).then(fn = reconstruct,
-        inputs = [tar_prompt,
-                  image_caption,
-                  tar_cfg_scale,
-                  skip,
-                  wts, zs,
-                  do_reconstruction,
-                  reconstruction,
-                  reconstruct_button],
+        inputs = [
+            tar_prompt,
+            image_caption,
+            tar_cfg_scale,
+            skip,
+            wts,
+            zs,
+            attention_store,
+            text_cross_attention_maps,
+            do_reconstruction,
+            reconstruction,
+            reconstruct_button
+        ],
        outputs = [ddpm_edited_image,reconstruction, ddpm_edited_image, do_reconstruction, reconstruct_button])
 
    randomize_seed.change(
@@ -905,8 +928,8 @@ with gr.Blocks(css="style.css") as demo:
            threshold_3,
            seed
        ],
-        outputs=[share_btn_container, box1, concept_1, guidnace_scale_1,neg_guidance_1, row1, row2,box2, concept_2, guidnace_scale_2,neg_guidance_2,row2, row3,sega_concepts_counter],
-        cache_examples=True
+        outputs=[share_btn_container, box1, concept_1, guidnace_scale_1,neg_guidance_1, row1, row2, box2, concept_2, guidnace_scale_2,neg_guidance_2,row2, row3,sega_concepts_counter],
+        cache_examples=False
    )
 
 demo.queue(default_concurrency_limit=1)
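Note on the first hunk: the new load_models() keeps the same device fallback as before, CUDA with float16 first, then Apple MPS (where the code drops to float32), then CPU. A minimal standalone sketch of that selection logic, runnable with only torch installed; pick_device_and_dtype is a hypothetical name for illustration, not a function in app.py:

import torch

def pick_device_and_dtype():
    # Mirrors the fallback in load_models(): CUDA first with float16,
    # then MPS with float32, otherwise CPU (dtype left at float16,
    # exactly as in the committed code).
    if torch.cuda.is_available():
        return "cuda", torch.float16
    if torch.backends.mps.is_available():
        return "mps", torch.float32
    return "cpu", torch.float16

device, torch_dtype = pick_device_and_dtype()
print(device, torch_dtype)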
 
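The body of caption_image is not part of this diff; for orientation, here is a minimal sketch of how the two BLIP objects loaded in load_models() are typically used for captioning. The generation arguments are illustrative assumptions, not values from app.py:

from PIL import Image
from transformers import AutoProcessor, BlipForConditionalGeneration

blip_processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")

def caption_image(input_image: Image.Image) -> str:
    # Preprocess the image, generate caption token ids, decode to text.
    inputs = blip_processor(images=input_image, return_tensors="pt")
    out = blip_model.generate(**inputs, max_new_tokens=30)
    return blip_processor.decode(out[0], skip_special_tokens=True)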
 
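The reconstruct_button wiring relies on Gradio event chaining: .click() returns an event object, and .then() queues a second callback that runs only after the first completes, so the output image can be made visible before the slow reconstruct call starts. A self-contained sketch of the pattern; the component names here are illustrative, not from app.py:

import time
import gradio as gr

def slow_job():
    time.sleep(2)  # stand-in for the expensive reconstruction step
    return "done"

with gr.Blocks() as demo:
    btn = gr.Button("Run")
    status = gr.Textbox(label="status")
    result = gr.Textbox(label="result")
    # The first callback returns immediately to update the UI;
    # .then() chains the slow step after it finishes.
    btn.click(lambda: "working...", outputs=status).then(slow_job, outputs=result)

demo.launch()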
 
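The cache_examples flip from True to False changes when the example handler runs: with caching on, Gradio executes the function for every example at startup and stores the outputs, while with it off, examples run on demand when clicked, which avoids the expensive warm-up. A minimal sketch of the flag on gr.Examples; generate is a hypothetical handler, not the app's edit function:

import gradio as gr

def generate(prompt):
    # Hypothetical stand-in for the real diffusion call.
    return f"edited: {prompt}"

with gr.Blocks() as demo:
    prompt = gr.Textbox(label="prompt")
    output = gr.Textbox(label="output")
    gr.Examples(
        examples=[["a photo of a cat"], ["a watercolor landscape"]],
        inputs=[prompt],
        outputs=[output],
        fn=generate,
        cache_examples=False,  # run on click instead of precomputing at launch
    )

demo.launch()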