eienmojiki commited on
Commit
6dbd85a
·
verified ·
1 Parent(s): 517bc28

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -21
app.py CHANGED
@@ -61,20 +61,6 @@ def get_scheduler(scheduler_config: Dict, name: str) -> Optional[Callable]:
61
  }
62
  return scheduler_factory_map.get(name, lambda: None)()
63
 
64
- def load_pipeline(model_name):
65
- if torch.cuda.is_available():
66
- pipe = StableDiffusionXLPipeline.from_pretrained(
67
- model_name,
68
- torch_dtype=torch.float16,
69
- custom_pipeline="lpw_stable_diffusion_xl",
70
- safety_checker = None,
71
- use_safetensors=True,
72
- add_watermarker=False,
73
- use_auth_token=HF_TOKEN
74
- )
75
- pipe.to(device)
76
- return pipe
77
-
78
  @spaces.GPU
79
  def generate(
80
  prompt: str,
@@ -88,6 +74,17 @@ def generate(
88
  clip_skip: int = 1,
89
  progress=gr.Progress(track_tqdm=True),
90
  ):
 
 
 
 
 
 
 
 
 
 
 
91
 
92
  generator = seed_everything(seed)
93
  pipe.scheduler = get_scheduler(pipe.scheduler.config, sampler)
@@ -116,13 +113,6 @@ def generate(
116
  except Exception as e:
117
  print(f"An error occurred: {e}")
118
 
119
- if torch.cuda.is_available():
120
- pipe = load_pipeline(MODEL)
121
- pipe.to(device)
122
- print("Loaded on Device!")
123
- else:
124
- pipe = None
125
-
126
  with gr.Blocks(
127
  theme=gr.themes.Soft()
128
  ) as demo:
 
61
  }
62
  return scheduler_factory_map.get(name, lambda: None)()
63
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  @spaces.GPU
65
  def generate(
66
  prompt: str,
 
74
  clip_skip: int = 1,
75
  progress=gr.Progress(track_tqdm=True),
76
  ):
77
+ if torch.cuda.is_available():
78
+ pipe = StableDiffusionXLPipeline.from_pretrained(
79
+ MODEL,
80
+ torch_dtype=torch.float16,
81
+ custom_pipeline="lpw_stable_diffusion_xl",
82
+ safety_checker=None,
83
+ use_safetensors=True,
84
+ add_watermarker=False,
85
+ use_auth_token=HF_TOKEN
86
+ )
87
+ pipe.to(device)
88
 
89
  generator = seed_everything(seed)
90
  pipe.scheduler = get_scheduler(pipe.scheduler.config, sampler)
 
113
  except Exception as e:
114
  print(f"An error occurred: {e}")
115
 
 
 
 
 
 
 
 
116
  with gr.Blocks(
117
  theme=gr.themes.Soft()
118
  ) as demo: