CHEONMA010 commited on
Commit
442bd7f
·
verified ·
1 Parent(s): d28510c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -5
app.py CHANGED
@@ -13,18 +13,19 @@ class StableBuddyApp:
13
  def __init__(self):
14
  # Set up the Stable Diffusion pipeline
15
  model_id = "CompVis/stable-diffusion-v1-4"
16
- device = "cuda" if torch.cuda.is_available() else "cpu"
17
 
18
  # Get the auth_token from the environment variable
19
  auth_token = os.getenv("AUTH_TOKEN")
20
  if not auth_token:
21
  raise ValueError("AUTH_TOKEN environment variable is not set.")
22
 
23
- # Use float16 and fp16 so that stable diffusion can work on 4GB VRAM float 32 for cpu
 
24
  self.pipe = StableDiffusionPipeline.from_pretrained(
25
- model_id, revision='fp16', torch_dtype=torch.float32, use_auth_token=auth_token
26
  )
27
- self.pipe.to(device)
28
 
29
  def generate_image(self, prompt):
30
  """Generate an image based on the prompt."""
@@ -44,7 +45,8 @@ class StableBuddyApp:
44
 
45
  except Exception as e:
46
  print(f"An error occurred: {e}")
47
- return None
 
48
  # Create an instance of the StableBuddyApp
49
  stable_buddy_app = StableBuddyApp()
50
 
 
13
  def __init__(self):
14
  # Set up the Stable Diffusion pipeline
15
  model_id = "CompVis/stable-diffusion-v1-4"
16
+ self.device = "cuda" if torch.cuda.is_available() else "cpu" # Store device as a class attribute
17
 
18
  # Get the auth_token from the environment variable
19
  auth_token = os.getenv("AUTH_TOKEN")
20
  if not auth_token:
21
  raise ValueError("AUTH_TOKEN environment variable is not set.")
22
 
23
+ # Use float16 for GPU and float32 for CPU to manage VRAM
24
+ torch_dtype = torch.float16 if self.device == "cuda" else torch.float32
25
  self.pipe = StableDiffusionPipeline.from_pretrained(
26
+ model_id, revision='fp16', torch_dtype=torch_dtype, use_auth_token=auth_token
27
  )
28
+ self.pipe.to(self.device)
29
 
30
  def generate_image(self, prompt):
31
  """Generate an image based on the prompt."""
 
45
 
46
  except Exception as e:
47
  print(f"An error occurred: {e}")
48
+ return None
49
+
50
  # Create an instance of the StableBuddyApp
51
  stable_buddy_app = StableBuddyApp()
52