Krebzonide commited on
Commit
63aa355
·
1 Parent(s): 7588cca

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -16
app.py CHANGED
@@ -13,7 +13,6 @@ model_list = ["stabilityai/stable-diffusion-xl-base-1.0",
13
  model_url_list = ["stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0.safetensors",
14
  "Krebzonide/Colossus_Project_XL/blob/main/colossusProjectXLSFW_v202BakedVAE.safetensors",
15
  "Krebzonide/Sevenof9_v3_sdxl/blob/main/nsfwSevenof9V3_nsfwSevenof9V3.safetensors"]
16
- current_model = -1
17
 
18
  vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
19
  pipe = None
@@ -50,17 +49,11 @@ def generate(prompt, neg_prompt, samp_steps, guide_scale, batch_size, seed, heig
50
  ).images
51
  return [(img, f"Image {i+1}") for i, img in enumerate(images)]
52
 
53
- def set_base_model(base_model_id):
54
  global pipe
55
- global current_model
56
  global model_list
57
  global model_url_list
58
  new_model = model_list.index(base_model_id)
59
- if new_model == current_model:
60
- return pipe
61
- del pipe
62
- torch.cuda.empty_cache()
63
- gc.collect()
64
  model_url = "https://huggingface.co/" + model_url_list[new_model]
65
  pipe = StableDiffusionXLPipeline.from_single_file(
66
  model_url,
@@ -71,8 +64,10 @@ def set_base_model(base_model_id):
71
  use_auth_token="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"
72
  )
73
  pipe.to("cuda")
74
- current_model = new_model
 
75
  return pipe
 
76
 
77
  with gr.Blocks(css=css) as demo:
78
  with gr.Column():
@@ -88,13 +83,13 @@ with gr.Blocks(css=css) as demo:
88
  height = gr.Slider(label="Height", value=1024, minimum=512, maximum=2048, step=16)
89
  width = gr.Slider(label="Width", value=1024, minimum=512, maximum=2048, step=16)
90
  gallery = gr.Gallery(label="Generated images", height=800)
 
 
 
 
91
  with gr.Row():
92
- model_id = gr.Dropdown(model_list, label="model")
93
  change_model_btn = gr.Button("Update Model", elem_classes="btn-green")
 
94
 
95
- submit_btn.click(generate, [prompt, negative_prompt, samp_steps, guide_scale, batch_size, seed, height, width], [gallery], queue=True)
96
- change_model_btn.click(set_base_model, [model_id], queue = False)
97
-
98
- pipe = set_base_model(model_list[0])
99
- demo.queue(1)
100
- demo.launch(debug=True)
 
13
  model_url_list = ["stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0.safetensors",
14
  "Krebzonide/Colossus_Project_XL/blob/main/colossusProjectXLSFW_v202BakedVAE.safetensors",
15
  "Krebzonide/Sevenof9_v3_sdxl/blob/main/nsfwSevenof9V3_nsfwSevenof9V3.safetensors"]
 
16
 
17
  vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
18
  pipe = None
 
49
  ).images
50
  return [(img, f"Image {i+1}") for i, img in enumerate(images)]
51
 
52
+ def set_base_model(base_model_id, progress=gr.Progress(track_tqdm=True)):
53
  global pipe
 
54
  global model_list
55
  global model_url_list
56
  new_model = model_list.index(base_model_id)
 
 
 
 
 
57
  model_url = "https://huggingface.co/" + model_url_list[new_model]
58
  pipe = StableDiffusionXLPipeline.from_single_file(
59
  model_url,
 
64
  use_auth_token="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"
65
  )
66
  pipe.to("cuda")
67
+ intro.close()
68
+ demo.launch(debug=True)
69
  return pipe
70
+
71
 
72
  with gr.Blocks(css=css) as demo:
73
  with gr.Column():
 
83
  height = gr.Slider(label="Height", value=1024, minimum=512, maximum=2048, step=16)
84
  width = gr.Slider(label="Width", value=1024, minimum=512, maximum=2048, step=16)
85
  gallery = gr.Gallery(label="Generated images", height=800)
86
+ submit_btn.click(generate, [prompt, negative_prompt, samp_steps, guide_scale, batch_size, seed, height, width], [gallery], queue=True)
87
+
88
+ with gr.Blocks(css=css) as intro:
89
+ with gr.Column():
90
  with gr.Row():
91
+ model_id = gr.Dropdown(model_list, label="model", value="stabilityai/stable-diffusion-xl-base-1.0")
92
  change_model_btn = gr.Button("Update Model", elem_classes="btn-green")
93
+ change_model_btn.click(set_base_model, [model_id])
94
 
95
+ intro.launch(debug=True)