MonsterMMORPG committed
Commit 7d405ed · 1 Parent(s): ba42139

Upload web-ui.py

Files changed (1):
  1. web-ui.py +71 -47
web-ui.py CHANGED
@@ -5,7 +5,7 @@ import numpy as np
 import torch
 from PIL import Image
 from insightface.app import FaceAnalysis
-from diffusers import StableDiffusionPipeline, DDIMScheduler, AutoencoderKL
+from diffusers import StableDiffusionPipeline, DDIMScheduler, AutoencoderKL, StableDiffusionXLPipeline
 from ip_adapter.ip_adapter_faceid import IPAdapterFaceIDPlus
 import argparse
 import random
@@ -13,6 +13,7 @@ from insightface.utils import face_align
 from pyngrok import ngrok
 import threading
 import time
+from ip_adapter.ip_adapter_faceid import IPAdapterFaceIDXL
 
 # Argument parser for command line options
 parser = argparse.ArgumentParser()
@@ -26,38 +27,45 @@ args = parser.parse_args()
 # Add new model names here
 static_model_names = [
     "SG161222/Realistic_Vision_V6.0_B1_noVAE",
-    "stablediffusionapi/rev-animated-v122-eol",
-    "Lykon/DreamShaper",
-    "stablediffusionapi/toonyou",
-    "stablediffusionapi/real-cartoon-3d",
-    "KBlueLeaf/kohaku-v2.1",
-    "nitrosocke/Ghibli-Diffusion",
-    "Linaqruf/anything-v3.0",
-    "jinaai/flat-2d-animerge",
-    "stablediffusionapi/realcartoon3d",
-    "stablediffusionapi/disney-pixar-cartoon",
-    "stablediffusionapi/pastel-mix-stylized-anime",
-    "stablediffusionapi/anything-v5",
+    "stablediffusionapi/rev-animated-v122-eol",
+    "Lykon/DreamShaper",
+    "stablediffusionapi/toonyou",
+    "stablediffusionapi/real-cartoon-3d",
+    "KBlueLeaf/kohaku-v2.1",
+    "nitrosocke/Ghibli-Diffusion",
+    "Linaqruf/anything-v3.0",
+    "jinaai/flat-2d-animerge",
+    "stablediffusionapi/realcartoon3d",
+    "stablediffusionapi/disney-pixar-cartoon",
+    "stablediffusionapi/pastel-mix-stylized-anime",
+    "stablediffusionapi/anything-v5",
     "SG161222/Realistic_Vision_V2.0",
     "SG161222/Realistic_Vision_V4.0_noVAE",
     "SG161222/Realistic_Vision_V5.1_noVAE",
-    #r"G:\model\model_diffusers"
+    "stablediffusionapi/anime-illust-diffusion-xl",
+    "stabilityai/stable-diffusion-xl-base-1.0",
+    #r"G:\model\model_diffusers"
 ]
 
 # Cache for loaded models
 model_cache = {}
 max_cache_size = args.cache_limit
 
-def convert_model(checkpoint_path, output_path):
+def convert_model(checkpoint_path, output_path, isSDXL):
     try:
-        pipe = StableDiffusionPipeline.from_single_file(checkpoint_path)
-        pipe.save_pretrained(output_path)
+        if isSDXL:
+            pipe = StableDiffusionXLPipeline.from_single_file(checkpoint_path)
+            pipe.save_pretrained(output_path)
+        else:
+            pipe = StableDiffusionPipeline.from_single_file(checkpoint_path)
+            pipe.save_pretrained(output_path)
         return f"Model converted and saved to {output_path}"
     except Exception as e:
         return f"Error: {str(e)}"
 
+
 # Function to load and cache model
-def load_model(model_name):
+def load_model(model_name, isSDXL):
     if model_name in model_cache:
         return model_cache[model_name]
     print(f"loading model {model_name}")
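For context: the conversion path relies on two stock diffusers APIs, from_single_file() (parses a monolithic .safetensors/.ckpt checkpoint into pipeline components) and save_pretrained() (writes the multi-folder diffusers layout that from_pretrained() expects). A minimal standalone sketch of the same round trip, with placeholder paths:

# Minimal sketch of the conversion that convert_model() performs; the paths
# below are placeholders, not files shipped with this repo.
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline

def convert(checkpoint_path, output_path, is_sdxl=False):
    # from_single_file() parses the single-file checkpoint into components
    cls = StableDiffusionXLPipeline if is_sdxl else StableDiffusionPipeline
    pipe = cls.from_single_file(checkpoint_path)
    # save_pretrained() writes the diffusers folder layout (unet/, vae/, ...)
    pipe.save_pretrained(output_path)

# convert(r"G:\model\model.safetensors", r"G:\model\model_diffusers", is_sdxl=False)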
@@ -76,32 +84,47 @@ def load_model(model_name):
         steps_offset=1,
     )
     vae_model_path = "stabilityai/sd-vae-ft-mse"
+    if isSDXL:
+        vae_model_path = "stabilityai/sdxl-vae"
     vae = AutoencoderKL.from_pretrained(vae_model_path).to(dtype=torch.float16)
 
-    # Load model based on the selected model name
-    pipe = StableDiffusionPipeline.from_pretrained(
-        model_name,
-        torch_dtype=torch.float16,
-        scheduler=noise_scheduler,
-        vae=vae,
-        feature_extractor=None,
-        safety_checker=None
-    ).to(device)
-
-    image_encoder_path = "h94/IP-Adapter/models/image_encoder"
-    ip_ckpt = "adapters/ip-adapter-faceid-plusv2_sd15.bin"
-    ip_model = IPAdapterFaceIDPlus(pipe, image_encoder_path, ip_ckpt, device)
+    if isSDXL:
+        pipe = StableDiffusionXLPipeline.from_pretrained(
+            model_name,
+            torch_dtype=torch.float16,
+            vae=vae,
+            scheduler=noise_scheduler,
+            add_watermarker=False,
+        ).to(device)
+    else:
+        # Load model based on the selected model name
+        pipe = StableDiffusionPipeline.from_pretrained(
+            model_name,
+            torch_dtype=torch.float16,
+            scheduler=noise_scheduler,
+            vae=vae,
+            feature_extractor=None,
+            safety_checker=None
+        ).to(device)
+
+    if isSDXL:
+        ip_ckpt = "adapters/ip-adapter-faceid_sdxl.bin"
+        ip_model = IPAdapterFaceIDXL(pipe, ip_ckpt, device)
+    else:
+        image_encoder_path = "h94/IP-Adapter/models/image_encoder"
+        ip_ckpt = "adapters/ip-adapter-faceid-plusv2_sd15.bin"
+        ip_model = IPAdapterFaceIDPlus(pipe, image_encoder_path, ip_ckpt, device)
 
     model_cache[model_name] = ip_model
     return ip_model
 
 # Function to process image and generate output
-def generate_image(input_image, positive_prompt, negative_prompt, width, height, model_name, num_inference_steps, seed, randomize_seed, num_images, batch_size, enable_shortcut, s_scale, custom_model_path):
+def generate_image(input_image, positive_prompt, negative_prompt, width, height, model_name, num_inference_steps, seed, randomize_seed, num_images, batch_size, enable_shortcut, s_scale, custom_model_path, isSDXL, cfg):
     saved_images = []
     if custom_model_path:
         model_name = custom_model_path
     # Load and prepare the model
-    ip_model = load_model(model_name)
+    ip_model = load_model(model_name, isSDXL)
 
     # Convert input image to the format expected by the model
     input_image = input_image.convert("RGB")
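The two adapter branches differ for a reason: FaceID Plus v2 (the SD 1.5 path) conditions on both an InsightFace identity embedding and a CLIP-encoded aligned face crop, which is why its constructor takes image_encoder_path; plain FaceID for SDXL uses the identity embedding alone. A sketch of how such an ip_model is typically driven, following the upstream IP-Adapter FaceID examples (the image path is a placeholder, and the exact keyword set in this app's generate call may differ):

# Sketch following the upstream IP-Adapter FaceID Plus v2 example; "face.jpg"
# is a placeholder and ip_model is assumed to come from load_model() above.
import cv2
import torch
from insightface.app import FaceAnalysis
from insightface.utils import face_align

app = FaceAnalysis(name="buffalo_l", providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
app.prepare(ctx_id=0, det_size=(640, 640))

image = cv2.imread("face.jpg")
faces = app.get(image)
# identity embedding used by all FaceID variants
faceid_embeds = torch.from_numpy(faces[0].normed_embedding).unsqueeze(0)
# aligned 224x224 face crop, consumed only by the Plus variants
face_image = face_align.norm_crop(image, landmark=faces[0].kps, image_size=224)

images = ip_model.generate(
    prompt="photo of a person in a garden",
    negative_prompt="blurry, low quality",
    face_image=face_image,
    faceid_embeds=faceid_embeds,
    shortcut=True,      # the v2 "shortcut" toggle (enable_shortcut in this UI)
    s_scale=1.0,
    num_samples=1,
    width=512, height=768,
    num_inference_steps=30,
    seed=2023,
)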
@@ -132,6 +155,7 @@ def generate_image(input_image, positive_prompt, negative_prompt, width, height,
         s_scale=s_scale,
         width=width,
         height=height,
+        guidance_scale=cfg,
         num_inference_steps=num_inference_steps,
         seed=seed,
     )
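The newly threaded-through guidance_scale is the classifier-free guidance weight: at each denoising step the pipeline runs conditional and unconditional noise predictions and extrapolates between them, so the cfg value directly controls prompt adherence. Schematically:

# Classifier-free guidance as diffusers pipelines apply it at each step (schematic):
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)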
@@ -156,6 +180,7 @@ with gr.Blocks() as demo:
     with gr.Row():
         width = gr.Number(value=512, label="Width")
         height = gr.Number(value=768, label="Height")
+        cfg = gr.Number(value=7.5, label="CFG")
     with gr.Row():
         num_inference_steps = gr.Number(value=30, label="Number of Inference Steps", step=1, minimum=10, maximum=100)
         seed = gr.Number(value=2023, label="Seed")
@@ -163,7 +188,8 @@
     with gr.Row():
         num_images = gr.Number(value=args.num_images, label="Number of Images to Generate", step=1, minimum=1)
         batch_size = gr.Number(value=1, label="Batch Size", step=1)
-    with gr.Row():
+    with gr.Row():
+        isSDXL = gr.Checkbox(value=False, label="Activate SDXL")
         enable_shortcut = gr.Checkbox(value=True, label="Enable Shortcut")
         s_scale = gr.Number(value=1.0, label="Scale Factor (s_scale)", step=0.1, minimum=0.5, maximum=4.0)
     with gr.Row():
@@ -177,39 +203,37 @@ with gr.Blocks() as demo:
         output_gallery = gr.Gallery(label="Generated Images")
         output_text = gr.Textbox(label="Output Info")
         display_seed = gr.Textbox(label="Used Seed", interactive=False)
-
+
     with gr.Row():
         checkpoint_path_input = gr.Textbox(label="Enter Checkpoint File Path, e.g. G:\model\model.safetensors")
         output_path_input = gr.Textbox(label="Enter Output Folder Path, e.g. G:\model\model_diffusers")
         convert_btn = gr.Button("Convert Model")
 
-
     generate_btn.click(
         generate_image,
-        inputs=[input_image, positive_prompt, negative_prompt, width, height, model_selector, num_inference_steps, seed, randomize_seed, num_images, batch_size, enable_shortcut, s_scale, custom_model_path],
+        inputs=[input_image, positive_prompt, negative_prompt, width, height, model_selector, num_inference_steps, seed, randomize_seed, num_images, batch_size, enable_shortcut, s_scale, custom_model_path, isSDXL, cfg],
         outputs=[output_gallery, output_text, display_seed]
     )
 
     convert_btn.click(
         convert_model,
-        inputs=[checkpoint_path_input, output_path_input],
+        inputs=[checkpoint_path_input, output_path_input, isSDXL],
         outputs=[gr.Text(label="Conversion Status")],
     )
-
-#sadly doesnt work
+
+# Function to start ngrok for tunneling
 def start_ngrok():
-    print("1")
-    time.sleep(10)  # Delay for 10 seconds to ensure Gradio starts first
-    print("2")
+    print("Starting ngrok...")
+    time.sleep(10)  # Delay to ensure Gradio starts first
     ngrok.set_auth_token(args.ngrok_token)
     public_url = ngrok.connect(port=7860)  # Adjust to your Gradio app's port
     print(f"ngrok tunnel started at {public_url}")
 
 if __name__ == "__main__":
-    #if args.ngrok_token:
-        # Start ngrok in a daemon thread with a delay
-        # ngrok_thread = threading.Thread(target=start_ngrok, daemon=False)
-        # ngrok_thread.start()
+    if args.ngrok_token:
+        # Start ngrok in a separate thread with a delay
+        ngrok_thread = threading.Thread(target=start_ngrok, daemon=True)
+        ngrok_thread.start()
 
     # Launch the Gradio app
     demo.launch(share=args.share, inbrowser=True)
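A note on the re-enabled tunnel: the 10-second sleep lets Gradio bind port 7860 before the tunnel opens, and daemon=True lets the process exit without waiting on the thread. One caveat worth flagging: ngrok.connect(port=7860) matches older pyngrok; in pyngrok 5+ the address is the first positional argument. A standalone sketch under that assumption (the token is a placeholder):

# Standalone sketch of the ngrok startup path; "NGROK_TOKEN" is a placeholder.
import threading
import time
from pyngrok import ngrok

def start_ngrok(token, port=7860):
    time.sleep(10)                  # give Gradio time to bind the port first
    ngrok.set_auth_token(token)
    tunnel = ngrok.connect(port)    # pyngrok 5+: address as first argument
    print(f"ngrok tunnel started at {tunnel.public_url}")

threading.Thread(target=start_ngrok, args=("NGROK_TOKEN",), daemon=True).start()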
 