Krebzonide commited on
Commit
a02443f
·
1 Parent(s): fae7b1e

finished code to change model

Browse files
Files changed (1) hide show
  1. app.py +10 -3
app.py CHANGED
@@ -13,6 +13,7 @@ model_list = ["stabilityai/stable-diffusion-xl-base-1.0",
13
  model_url_list = ["stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0.safetensors",
14
  "Krebzonide/Colossus_Project_XL/blob/main/colossusProjectXLSFW_v202BakedVAE.safetensors",
15
  "Krebzonide/Sevenof9_v3_sdxl/blob/main/nsfwSevenof9V3_nsfwSevenof9V3.safetensors"]
 
16
 
17
  vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
18
  pipe = None
@@ -51,6 +52,12 @@ def generate(prompt, neg_prompt, samp_steps, guide_scale, batch_size, seed, heig
51
 
52
  def set_base_model(base_model_id):
53
  global pipe
 
 
 
 
 
 
54
  del pipe
55
  torch.cuda.empty_cache()
56
  gc.collect()
@@ -63,8 +70,8 @@ def set_base_model(base_model_id):
63
  use_auth_token="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"
64
  )
65
  pipe.to("cuda")
 
66
  return pipe
67
-
68
 
69
  with gr.Blocks(css=css) as demo:
70
  with gr.Column():
@@ -85,8 +92,8 @@ with gr.Blocks(css=css) as demo:
85
  change_model_btn = gr.Button("Update Model", elem_classes="btn-green")
86
 
87
  submit_btn.click(generate, [prompt, negative_prompt, samp_steps, guide_scale, batch_size, seed, height, width], [gallery], queue=True)
88
- change_model_btn.click(set_base_model, [model_id], queue = false)
89
 
90
- pipe = set_base_model(model_list[0])
91
  demo.queue(1)
92
  demo.launch(debug=True)
 
13
  model_url_list = ["stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0.safetensors",
14
  "Krebzonide/Colossus_Project_XL/blob/main/colossusProjectXLSFW_v202BakedVAE.safetensors",
15
  "Krebzonide/Sevenof9_v3_sdxl/blob/main/nsfwSevenof9V3_nsfwSevenof9V3.safetensors"]
16
+ current_model = 0
17
 
18
  vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
19
  pipe = None
 
52
 
53
  def set_base_model(base_model_id):
54
  global pipe
55
+ global current_model
56
+ global model_list
57
+ global model_url_list
58
+ new_model = model_list.index(base_model_id)
59
+ if new_model == current_model:
60
+ return pipe
61
  del pipe
62
  torch.cuda.empty_cache()
63
  gc.collect()
 
70
  use_auth_token="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"
71
  )
72
  pipe.to("cuda")
73
+ current_model = new_model
74
  return pipe
 
75
 
76
  with gr.Blocks(css=css) as demo:
77
  with gr.Column():
 
92
  change_model_btn = gr.Button("Update Model", elem_classes="btn-green")
93
 
94
  submit_btn.click(generate, [prompt, negative_prompt, samp_steps, guide_scale, batch_size, seed, height, width], [gallery], queue=True)
95
+ change_model_btn.click(set_base_model, [model_id], queue = False)
96
 
97
+ pipe = set_base_model(model_list[current_model])
98
  demo.queue(1)
99
  demo.launch(debug=True)