yunquan commited on
Commit
04891fb
·
verified ·
1 Parent(s): 68bc1b3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -11
app.py CHANGED
@@ -6,36 +6,43 @@ import torch
6
  import os
7
  import logging
8
 
 
9
  logging.basicConfig(level=logging.INFO)
10
 
11
-
12
  # Retrieve Hugging Face access token from environment variables
13
  access_token = os.getenv("HF_ACCESS_TOKEN")
14
 
15
  # Set device
16
  device = "cuda" if torch.cuda.is_available() else "cpu"
17
 
18
-
19
  # Global variable for the pipeline
20
  pipe = None
21
 
22
  def load_model():
23
  global pipe
24
  if pipe is None:
25
- logging.info("Loading the Stable Diffusion model...")
26
- pipe = StableDiffusionPipeline.from_pretrained(
27
- "stabilityai/stable-diffusion-3-medium",
28
- torch_dtype=torch.float16,
29
- use_auth_token=access_token,
30
- cache_dir="/path/to/cache" # specify cache directory
31
- )
32
- pipe = pipe.to(device)
33
- logging.info("Model loaded successfully.")
 
 
 
 
34
 
35
  MAX_SEED = np.iinfo(np.int32).max
36
  MAX_IMAGE_SIZE = 1024
37
 
38
  def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps):
 
 
 
 
39
  if randomize_seed:
40
  seed = random.randint(0, MAX_SEED)
41
  generator = torch.Generator().manual_seed(seed)
@@ -146,3 +153,4 @@ with gr.Blocks(css=css) as demo:
146
  )
147
 
148
  demo.queue().launch()
 
 
6
  import os
7
  import logging
8
 
9
+ # Setup logging
10
  logging.basicConfig(level=logging.INFO)
11
 
 
12
  # Retrieve Hugging Face access token from environment variables
13
  access_token = os.getenv("HF_ACCESS_TOKEN")
14
 
15
  # Set device
16
  device = "cuda" if torch.cuda.is_available() else "cpu"
17
 
 
18
  # Global variable for the pipeline
19
  pipe = None
20
 
21
  def load_model():
22
  global pipe
23
  if pipe is None:
24
+ try:
25
+ logging.info("Loading the Stable Diffusion model...")
26
+ pipe = StableDiffusionPipeline.from_pretrained(
27
+ "stabilityai/stable-diffusion-3-medium",
28
+ torch_dtype=torch.float16,
29
+ use_auth_token=access_token,
30
+ cache_dir="/path/to/cache" # specify cache directory if needed
31
+ )
32
+ pipe = pipe.to(device)
33
+ logging.info("Model loaded successfully.")
34
+ except Exception as e:
35
+ logging.error(f"Failed to load model: {e}")
36
+ pipe = None
37
 
38
  MAX_SEED = np.iinfo(np.int32).max
39
  MAX_IMAGE_SIZE = 1024
40
 
41
  def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps):
42
+ load_model() # Ensure the model is loaded
43
+ if pipe is None:
44
+ raise RuntimeError("Model failed to load.")
45
+
46
  if randomize_seed:
47
  seed = random.randint(0, MAX_SEED)
48
  generator = torch.Generator().manual_seed(seed)
 
153
  )
154
 
155
  demo.queue().launch()
156
+