sbicy committed (verified)
Commit 5c89ccc · Parent(s): b9a6225

Update app.py

Files changed (1):
  app.py +28 -17
app.py CHANGED
@@ -3,33 +3,44 @@ from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
 import torch
 
 def load_model():
-    # Specify the Stable Diffusion pipeline with an appropriate model type
-    pipeline = StableDiffusionPipeline.from_pretrained(
-        "stabilityai/stable-diffusion-2-1",
-        torch_dtype=torch.float16,
-        revision="fp16",
-        safety_checker=None  # Disable safety checker if necessary
-    )
-
-    # Set the scheduler (optional but recommended)
+    try:
+        # Load the model with explicit variant for half-precision weights
+        pipeline = StableDiffusionPipeline.from_pretrained(
+            "stabilityai/stable-diffusion-2-1",
+            torch_dtype=torch.float16,
+            variant="fp16",  # Updated from 'revision' to 'variant'
+            safety_checker=None  # Disable safety checker for faster inference
+        )
+    except Exception as e:
+        print(f"Error loading the model: {e}")
+        raise
+
+    # Configure the scheduler for faster generation
     pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
-
-    # Move pipeline to GPU or ZeroGPU
-    pipeline = pipeline.to("cuda")  # or ZeroGPU-specific setup
-
+
+    # Move to CPU if no GPU is available
+    try:
+        pipeline = pipeline.to("cuda" if torch.cuda.is_available() else "cpu")
+    except Exception as e:
+        print(f"Error moving the model to device: {e}")
+        raise
+
     return pipeline
 
 # Initialize the model
 try:
     model = load_model()
 except Exception as e:
-    print(f"Error loading the model: {e}")
+    print(f"Error initializing the model: {e}")
 
 # Define Gradio interface
 def generate(prompt, guidance_scale=7.5, num_inference_steps=50):
-    # Generate the image
-    images = model(prompt, guidance_scale=guidance_scale, num_inference_steps=num_inference_steps).images
-    return images[0]
+    try:
+        # Generate image from the prompt
+        images = model(prompt, guidance_scale=guidance_scale, num_inference_steps=num_inference_steps).images
+        return images[0]
+    except Exception as e:
+        return f"Error generating image: {e}"
 
 # Gradio Interface
 with gr.Blocks() as demo:
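
Two notes on the change. First, recent diffusers releases select half-precision weight files with the variant="fp16" argument to from_pretrained, while revision refers to a git revision of the model repo, so the switch above matches current API usage. Second, the hunk ends at the opening of the with gr.Blocks() as demo: block, so the UI wiring itself is outside this diff. Below is a minimal sketch of how generate() could be hooked into a Gradio Blocks interface; the component names, labels, and layout are assumptions for illustration, not the actual contents of app.py.

# Hypothetical wiring below the diff's cutoff point; the real app.py may differ.
import gradio as gr

with gr.Blocks() as demo:
    gr.Markdown("# Stable Diffusion 2.1")
    prompt = gr.Textbox(label="Prompt")
    guidance_scale = gr.Slider(1.0, 20.0, value=7.5, label="Guidance scale")
    num_steps = gr.Slider(10, 100, value=50, step=1, label="Inference steps")
    output = gr.Image(label="Generated image")
    generate_btn = gr.Button("Generate")
    # generate() is the function defined earlier in app.py
    generate_btn.click(fn=generate, inputs=[prompt, guidance_scale, num_steps], outputs=output)

demo.launch()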