yeq6x commited on
Commit
2d7003c
·
1 Parent(s): 2189162

load_model

Browse files
Files changed (1) hide show
  1. app.py +9 -6
app.py CHANGED
@@ -11,10 +11,14 @@ import random
11
  import spaces
12
 
13
  pipe = None
14
- device = "cuda" if torch.cuda.is_available() else "cpu"
 
15
 
16
  def load_model():
17
- global device
 
 
 
18
  pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
19
  "yeq6x/animagine_position_map",
20
  controlnet=ControlNetModel.from_pretrained("yeq6x/Image2PositionColor_v3"),
@@ -23,8 +27,6 @@ def load_model():
23
 
24
  return pipe
25
 
26
- pipe = load_model()
27
-
28
  def convert_pil_to_opencv(pil_image):
29
  return np.array(pil_image)
30
 
@@ -139,9 +141,8 @@ def outpaint_image(image):
139
 
140
  @spaces.GPU
141
  def predict_image(cond_image, prompt, negative_prompt):
142
- print("Processing...")
143
  global pipe
144
- print("Processing...")
145
  generator = torch.Generator()
146
  generator.manual_seed(random.randint(0, 2147483647))
147
  image = pipe(
@@ -160,6 +161,8 @@ def predict_image(cond_image, prompt, negative_prompt):
160
 
161
  return image
162
 
 
 
163
  # Gradioアプリケーション
164
  with gr.Blocks() as demo:
165
  gr.Markdown("## Position Map Visualizer")
 
11
  import spaces
12
 
13
  pipe = None
14
+ device = None
15
+ torch_dtype = None
16
 
17
  def load_model():
18
+ global pipe, device, torch_dtype
19
+ device = "cuda" if torch.cuda.is_available() else "cpu"
20
+ torch_dtype = torch.float16 if device == "cuda" else torch.float32
21
+
22
  pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
23
  "yeq6x/animagine_position_map",
24
  controlnet=ControlNetModel.from_pretrained("yeq6x/Image2PositionColor_v3"),
 
27
 
28
  return pipe
29
 
 
 
30
  def convert_pil_to_opencv(pil_image):
31
  return np.array(pil_image)
32
 
 
141
 
142
  @spaces.GPU
143
  def predict_image(cond_image, prompt, negative_prompt):
144
+ print("predict position map")
145
  global pipe
 
146
  generator = torch.Generator()
147
  generator.manual_seed(random.randint(0, 2147483647))
148
  image = pipe(
 
161
 
162
  return image
163
 
164
+ load_model()
165
+
166
  # Gradioアプリケーション
167
  with gr.Blocks() as demo:
168
  gr.Markdown("## Position Map Visualizer")