nevreal commited on
Commit
1888310
·
verified ·
1 Parent(s): 0dc5179

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -11
app.py CHANGED
@@ -3,42 +3,57 @@ from diffusers import StableDiffusionPipeline
3
  import torch
4
 
5
  # Function to automatically switch between GPU and CPU
6
- def load_model(model_id):
7
  if torch.cuda.is_available():
8
  device = "cuda"
9
  info = "Running on GPU (CUDA)"
10
  else:
11
  device = "cpu"
12
  info = "Running on CPU"
13
-
14
- # Load the model dynamically on the correct device
15
- pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16 if device == "cuda" else torch.float32)
16
  pipe = pipe.to(device)
 
 
 
 
 
 
17
 
18
  return pipe, info
19
 
 
 
 
 
 
 
 
 
20
  # Function for text-to-image generation with dynamic model ID and device info
21
- def generate_image(model_id, prompt):
22
- pipe, info = load_model(model_id)
23
  image = pipe(prompt).images[0]
24
  return image, info
25
 
26
  # Create the Gradio interface
27
  with gr.Blocks() as demo:
28
- gr.Markdown("## Custom Text-to-Image Generator")
29
-
30
  with gr.Row():
31
  with gr.Column():
32
- model_id = gr.Textbox(label="Enter Model ID (e.g., nevreal/vMurderDrones)", placeholder="Model ID")
 
33
  prompt = gr.Textbox(label="Enter your prompt", placeholder="Describe the image you want to generate")
34
  generate_btn = gr.Button("Generate Image")
35
 
36
  with gr.Column():
37
  output_image = gr.Image(label="Generated Image")
38
- device_info = gr.Markdown() # To display if GPU or CPU is used
39
 
40
  # Link the button to the image generation function
41
- generate_btn.click(fn=generate_image, inputs=[model_id, prompt], outputs=[output_image, device_info])
42
 
43
  # Launch the app
44
  demo.launch()
 
3
  import torch
4
 
5
  # Function to automatically switch between GPU and CPU
6
+ def load_model(base_model_id, adapter_model_id=None):
7
  if torch.cuda.is_available():
8
  device = "cuda"
9
  info = "Running on GPU (CUDA)"
10
  else:
11
  device = "cpu"
12
  info = "Running on CPU"
13
+
14
+ # Load the base model dynamically on the correct device
15
+ pipe = StableDiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16 if device == "cuda" else torch.float32)
16
  pipe = pipe.to(device)
17
+
18
+ # If an adapter model is provided, load and merge the adapter model
19
+ if adapter_model_id:
20
+ adapter_model = StableDiffusionPipeline.from_pretrained(adapter_model_id, torch_dtype=torch.float16 if device == "cuda" else torch.float32)
21
+ pipe.unet.load_attn_procs(adapter_model_id) # This applies the adapter like LoRA to the model's UNet
22
+ info += f" with Adapter Model: {adapter_model_id}"
23
 
24
  return pipe, info
25
 
26
+
27
+ if torch.cuda.is_available():
28
+ device = "cuda"
29
+ info = "Running on GPU (CUDA) 🔥"
30
+ else:
31
+ device = "cpu"
32
+ info = "Running on CPU 🥶"
33
+
34
  # Function for text-to-image generation with dynamic model ID and device info
35
+ def generate_image(base_model_id, adapter_model_id, prompt):
36
+ pipe, info = load_model(base_model_id, adapter_model_id)
37
  image = pipe(prompt).images[0]
38
  return image, info
39
 
40
  # Create the Gradio interface
41
  with gr.Blocks() as demo:
42
+ gr.Markdown("## Custom Text-to-Image Generator with Adapter Support")
43
+ gr.Markdown(f"{info}")
44
  with gr.Row():
45
  with gr.Column():
46
+ base_model_id = gr.Textbox(label="Enter Base Model ID (e.g., CompVis/stable-diffusion-v1-4)", placeholder="Base Model ID")
47
+ adapter_model_id = gr.Textbox(label="Enter Adapter Model ID (optional, e.g., nevreal/vMurderDrones-Lora)", placeholder="Adapter Model ID (optional)", value="")
48
  prompt = gr.Textbox(label="Enter your prompt", placeholder="Describe the image you want to generate")
49
  generate_btn = gr.Button("Generate Image")
50
 
51
  with gr.Column():
52
  output_image = gr.Image(label="Generated Image")
53
+ device_info = gr.Markdown() # To display if GPU or CPU is used and whether an adapter is applied
54
 
55
  # Link the button to the image generation function
56
+ generate_btn.click(fn=generate_image, inputs=[base_model_id, adapter_model_id, prompt], outputs=[output_image, device_info])
57
 
58
  # Launch the app
59
  demo.launch()