Spaces:
Runtime error
Runtime error
Commit
·
41c10f1
1
Parent(s):
8708f21
comment out extra code
Browse files
app.py
CHANGED
@@ -28,19 +28,19 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
28 |
generator = pipeline(model="facebook/sam-vit-base", task="mask-generation", points_per_batch=256)
|
29 |
#image_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
|
30 |
|
31 |
-
controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
|
32 |
-
|
33 |
-
)
|
34 |
|
35 |
-
pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
)
|
41 |
|
42 |
-
params["controlnet"] = controlnet_params
|
43 |
-
p_params = replicate(params)
|
44 |
|
45 |
|
46 |
with gr.Blocks() as demo:
|
@@ -83,89 +83,89 @@ with gr.Blocks() as demo:
|
|
83 |
|
84 |
return np.stack(mask_images)
|
85 |
|
86 |
-
def infer(
|
87 |
-
|
88 |
-
):
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
def _clear(sel_pix, img, mask, seg, out, prompt, neg_prompt, bg):
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
|
142 |
input_img.change(
|
143 |
generate_mask,
|
144 |
inputs=[input_img],
|
145 |
outputs=[mask_img],
|
146 |
)
|
147 |
-
submit.click(
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
)
|
152 |
-
clear.click(
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
)
|
169 |
|
170 |
if __name__ == "__main__":
|
171 |
demo.queue()
|
|
|
28 |
generator = pipeline(model="facebook/sam-vit-base", task="mask-generation", points_per_batch=256)
|
29 |
#image_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
|
30 |
|
31 |
+
# controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
|
32 |
+
# "SAMControlNet/sd-controlnet-sam-seg", dtype=jnp.float32
|
33 |
+
# )
|
34 |
|
35 |
+
# pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
|
36 |
+
# "runwayml/stable-diffusion-v1-5",
|
37 |
+
# controlnet=controlnet,
|
38 |
+
# revision="flax",
|
39 |
+
# dtype=jnp.bfloat16,
|
40 |
+
# )
|
41 |
|
42 |
+
# params["controlnet"] = controlnet_params
|
43 |
+
# p_params = replicate(params)
|
44 |
|
45 |
|
46 |
with gr.Blocks() as demo:
|
|
|
83 |
|
84 |
return np.stack(mask_images)
|
85 |
|
86 |
+
# def infer(
|
87 |
+
# image, prompts, negative_prompts, num_inference_steps=50, seed=4, num_samples=4
|
88 |
+
# ):
|
89 |
+
# try:
|
90 |
+
# rng = jax.random.PRNGKey(int(seed))
|
91 |
+
# num_inference_steps = int(num_inference_steps)
|
92 |
+
# image = Image.fromarray(image, mode="RGB")
|
93 |
+
# num_samples = max(jax.device_count(), int(num_samples))
|
94 |
+
# p_rng = jax.random.split(rng, jax.device_count())
|
95 |
+
|
96 |
+
# prompt_ids = pipe.prepare_text_inputs([prompts] * num_samples)
|
97 |
+
# negative_prompt_ids = pipe.prepare_text_inputs(
|
98 |
+
# [negative_prompts] * num_samples
|
99 |
+
# )
|
100 |
+
# processed_image = pipe.prepare_image_inputs([image] * num_samples)
|
101 |
+
|
102 |
+
# prompt_ids = shard(prompt_ids)
|
103 |
+
# negative_prompt_ids = shard(negative_prompt_ids)
|
104 |
+
# processed_image = shard(processed_image)
|
105 |
+
|
106 |
+
# output = pipe(
|
107 |
+
# prompt_ids=prompt_ids,
|
108 |
+
# image=processed_image,
|
109 |
+
# params=p_params,
|
110 |
+
# prng_seed=p_rng,
|
111 |
+
# num_inference_steps=num_inference_steps,
|
112 |
+
# neg_prompt_ids=negative_prompt_ids,
|
113 |
+
# jit=True,
|
114 |
+
# ).images
|
115 |
+
|
116 |
+
# del negative_prompt_ids
|
117 |
+
# del processed_image
|
118 |
+
# del prompt_ids
|
119 |
+
|
120 |
+
# output = output.reshape((num_samples,) + output.shape[-3:])
|
121 |
+
# final_image = [np.array(x * 255, dtype=np.uint8) for x in output]
|
122 |
+
# print(output.shape)
|
123 |
+
# del output
|
124 |
+
|
125 |
+
# except Exception as e:
|
126 |
+
# print("Error: " + str(e))
|
127 |
+
# final_image = [np.zeros((512, 512, 3), dtype=np.uint8)] * num_samples
|
128 |
+
# finally:
|
129 |
+
# gc.collect()
|
130 |
+
# return final_image
|
131 |
+
|
132 |
+
# def _clear(sel_pix, img, mask, seg, out, prompt, neg_prompt, bg):
|
133 |
+
# img = None
|
134 |
+
# mask = None
|
135 |
+
# seg = None
|
136 |
+
# out = None
|
137 |
+
# prompt = ""
|
138 |
+
# neg_prompt = ""
|
139 |
+
# bg = False
|
140 |
+
# return img, mask, seg, out, prompt, neg_prompt, bg
|
141 |
|
142 |
input_img.change(
|
143 |
generate_mask,
|
144 |
inputs=[input_img],
|
145 |
outputs=[mask_img],
|
146 |
)
|
147 |
+
# submit.click(
|
148 |
+
# infer,
|
149 |
+
# inputs=[mask_img, prompt_text, negative_prompt_text],
|
150 |
+
# outputs=[output_img],
|
151 |
+
# )
|
152 |
+
# clear.click(
|
153 |
+
# _clear,
|
154 |
+
# inputs=[
|
155 |
+
# input_img,
|
156 |
+
# mask_img,
|
157 |
+
# output_img,
|
158 |
+
# prompt_text,
|
159 |
+
# negative_prompt_text,
|
160 |
+
# ],
|
161 |
+
# outputs=[
|
162 |
+
# input_img,
|
163 |
+
# mask_img,
|
164 |
+
# output_img,
|
165 |
+
# prompt_text,
|
166 |
+
# negative_prompt_text,
|
167 |
+
# ],
|
168 |
+
# )
|
169 |
|
170 |
if __name__ == "__main__":
|
171 |
demo.queue()
|