Spaces:
Sleeping
Sleeping
load_model
Browse files
app.py
CHANGED
@@ -11,10 +11,14 @@ import random
|
|
11 |
import spaces
|
12 |
|
13 |
pipe = None
|
14 |
-
device =
|
|
|
15 |
|
16 |
def load_model():
|
17 |
-
global device
|
|
|
|
|
|
|
18 |
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
|
19 |
"yeq6x/animagine_position_map",
|
20 |
controlnet=ControlNetModel.from_pretrained("yeq6x/Image2PositionColor_v3"),
|
@@ -23,8 +27,6 @@ def load_model():
|
|
23 |
|
24 |
return pipe
|
25 |
|
26 |
-
pipe = load_model()
|
27 |
-
|
28 |
def convert_pil_to_opencv(pil_image):
|
29 |
return np.array(pil_image)
|
30 |
|
@@ -139,9 +141,8 @@ def outpaint_image(image):
|
|
139 |
|
140 |
@spaces.GPU
|
141 |
def predict_image(cond_image, prompt, negative_prompt):
|
142 |
-
print("
|
143 |
global pipe
|
144 |
-
print("Processing...")
|
145 |
generator = torch.Generator()
|
146 |
generator.manual_seed(random.randint(0, 2147483647))
|
147 |
image = pipe(
|
@@ -160,6 +161,8 @@ def predict_image(cond_image, prompt, negative_prompt):
|
|
160 |
|
161 |
return image
|
162 |
|
|
|
|
|
163 |
# Gradioアプリケーション
|
164 |
with gr.Blocks() as demo:
|
165 |
gr.Markdown("## Position Map Visualizer")
|
|
|
11 |
import spaces
|
12 |
|
13 |
pipe = None
|
14 |
+
device = None
|
15 |
+
torch_dtype = None
|
16 |
|
17 |
def load_model():
|
18 |
+
global pipe, device, torch_dtype
|
19 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
20 |
+
torch_dtype = torch.float16 if device == "cuda" else torch.float32
|
21 |
+
|
22 |
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
|
23 |
"yeq6x/animagine_position_map",
|
24 |
controlnet=ControlNetModel.from_pretrained("yeq6x/Image2PositionColor_v3"),
|
|
|
27 |
|
28 |
return pipe
|
29 |
|
|
|
|
|
30 |
def convert_pil_to_opencv(pil_image):
|
31 |
return np.array(pil_image)
|
32 |
|
|
|
141 |
|
142 |
@spaces.GPU
|
143 |
def predict_image(cond_image, prompt, negative_prompt):
|
144 |
+
print("predict position map")
|
145 |
global pipe
|
|
|
146 |
generator = torch.Generator()
|
147 |
generator.manual_seed(random.randint(0, 2147483647))
|
148 |
image = pipe(
|
|
|
161 |
|
162 |
return image
|
163 |
|
164 |
+
load_model()
|
165 |
+
|
166 |
# Gradioアプリケーション
|
167 |
with gr.Blocks() as demo:
|
168 |
gr.Markdown("## Position Map Visualizer")
|