Mrahsanahmad commited on
Commit
41c10f1
·
1 Parent(s): 8708f21

comment out extra code

Browse files
Files changed (1) hide show
  1. app.py +88 -88
app.py CHANGED
@@ -28,19 +28,19 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
28
  generator = pipeline(model="facebook/sam-vit-base", task="mask-generation", points_per_batch=256)
29
  #image_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
30
 
31
- controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
32
- "SAMControlNet/sd-controlnet-sam-seg", dtype=jnp.float32
33
- )
34
 
35
- pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
36
- "runwayml/stable-diffusion-v1-5",
37
- controlnet=controlnet,
38
- revision="flax",
39
- dtype=jnp.bfloat16,
40
- )
41
 
42
- params["controlnet"] = controlnet_params
43
- p_params = replicate(params)
44
 
45
 
46
  with gr.Blocks() as demo:
@@ -83,89 +83,89 @@ with gr.Blocks() as demo:
83
 
84
  return np.stack(mask_images)
85
 
86
- def infer(
87
- image, prompts, negative_prompts, num_inference_steps=50, seed=4, num_samples=4
88
- ):
89
- try:
90
- rng = jax.random.PRNGKey(int(seed))
91
- num_inference_steps = int(num_inference_steps)
92
- image = Image.fromarray(image, mode="RGB")
93
- num_samples = max(jax.device_count(), int(num_samples))
94
- p_rng = jax.random.split(rng, jax.device_count())
95
-
96
- prompt_ids = pipe.prepare_text_inputs([prompts] * num_samples)
97
- negative_prompt_ids = pipe.prepare_text_inputs(
98
- [negative_prompts] * num_samples
99
- )
100
- processed_image = pipe.prepare_image_inputs([image] * num_samples)
101
-
102
- prompt_ids = shard(prompt_ids)
103
- negative_prompt_ids = shard(negative_prompt_ids)
104
- processed_image = shard(processed_image)
105
-
106
- output = pipe(
107
- prompt_ids=prompt_ids,
108
- image=processed_image,
109
- params=p_params,
110
- prng_seed=p_rng,
111
- num_inference_steps=num_inference_steps,
112
- neg_prompt_ids=negative_prompt_ids,
113
- jit=True,
114
- ).images
115
-
116
- del negative_prompt_ids
117
- del processed_image
118
- del prompt_ids
119
-
120
- output = output.reshape((num_samples,) + output.shape[-3:])
121
- final_image = [np.array(x * 255, dtype=np.uint8) for x in output]
122
- print(output.shape)
123
- del output
124
-
125
- except Exception as e:
126
- print("Error: " + str(e))
127
- final_image = [np.zeros((512, 512, 3), dtype=np.uint8)] * num_samples
128
- finally:
129
- gc.collect()
130
- return final_image
131
-
132
- def _clear(sel_pix, img, mask, seg, out, prompt, neg_prompt, bg):
133
- img = None
134
- mask = None
135
- seg = None
136
- out = None
137
- prompt = ""
138
- neg_prompt = ""
139
- bg = False
140
- return img, mask, seg, out, prompt, neg_prompt, bg
141
 
142
  input_img.change(
143
  generate_mask,
144
  inputs=[input_img],
145
  outputs=[mask_img],
146
  )
147
- submit.click(
148
- infer,
149
- inputs=[mask_img, prompt_text, negative_prompt_text],
150
- outputs=[output_img],
151
- )
152
- clear.click(
153
- _clear,
154
- inputs=[
155
- input_img,
156
- mask_img,
157
- output_img,
158
- prompt_text,
159
- negative_prompt_text,
160
- ],
161
- outputs=[
162
- input_img,
163
- mask_img,
164
- output_img,
165
- prompt_text,
166
- negative_prompt_text,
167
- ],
168
- )
169
 
170
  if __name__ == "__main__":
171
  demo.queue()
 
28
  generator = pipeline(model="facebook/sam-vit-base", task="mask-generation", points_per_batch=256)
29
  #image_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
30
 
31
+ # controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
32
+ # "SAMControlNet/sd-controlnet-sam-seg", dtype=jnp.float32
33
+ # )
34
 
35
+ # pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
36
+ # "runwayml/stable-diffusion-v1-5",
37
+ # controlnet=controlnet,
38
+ # revision="flax",
39
+ # dtype=jnp.bfloat16,
40
+ # )
41
 
42
+ # params["controlnet"] = controlnet_params
43
+ # p_params = replicate(params)
44
 
45
 
46
  with gr.Blocks() as demo:
 
83
 
84
  return np.stack(mask_images)
85
 
86
+ # def infer(
87
+ # image, prompts, negative_prompts, num_inference_steps=50, seed=4, num_samples=4
88
+ # ):
89
+ # try:
90
+ # rng = jax.random.PRNGKey(int(seed))
91
+ # num_inference_steps = int(num_inference_steps)
92
+ # image = Image.fromarray(image, mode="RGB")
93
+ # num_samples = max(jax.device_count(), int(num_samples))
94
+ # p_rng = jax.random.split(rng, jax.device_count())
95
+
96
+ # prompt_ids = pipe.prepare_text_inputs([prompts] * num_samples)
97
+ # negative_prompt_ids = pipe.prepare_text_inputs(
98
+ # [negative_prompts] * num_samples
99
+ # )
100
+ # processed_image = pipe.prepare_image_inputs([image] * num_samples)
101
+
102
+ # prompt_ids = shard(prompt_ids)
103
+ # negative_prompt_ids = shard(negative_prompt_ids)
104
+ # processed_image = shard(processed_image)
105
+
106
+ # output = pipe(
107
+ # prompt_ids=prompt_ids,
108
+ # image=processed_image,
109
+ # params=p_params,
110
+ # prng_seed=p_rng,
111
+ # num_inference_steps=num_inference_steps,
112
+ # neg_prompt_ids=negative_prompt_ids,
113
+ # jit=True,
114
+ # ).images
115
+
116
+ # del negative_prompt_ids
117
+ # del processed_image
118
+ # del prompt_ids
119
+
120
+ # output = output.reshape((num_samples,) + output.shape[-3:])
121
+ # final_image = [np.array(x * 255, dtype=np.uint8) for x in output]
122
+ # print(output.shape)
123
+ # del output
124
+
125
+ # except Exception as e:
126
+ # print("Error: " + str(e))
127
+ # final_image = [np.zeros((512, 512, 3), dtype=np.uint8)] * num_samples
128
+ # finally:
129
+ # gc.collect()
130
+ # return final_image
131
+
132
+ # def _clear(sel_pix, img, mask, seg, out, prompt, neg_prompt, bg):
133
+ # img = None
134
+ # mask = None
135
+ # seg = None
136
+ # out = None
137
+ # prompt = ""
138
+ # neg_prompt = ""
139
+ # bg = False
140
+ # return img, mask, seg, out, prompt, neg_prompt, bg
141
 
142
  input_img.change(
143
  generate_mask,
144
  inputs=[input_img],
145
  outputs=[mask_img],
146
  )
147
+ # submit.click(
148
+ # infer,
149
+ # inputs=[mask_img, prompt_text, negative_prompt_text],
150
+ # outputs=[output_img],
151
+ # )
152
+ # clear.click(
153
+ # _clear,
154
+ # inputs=[
155
+ # input_img,
156
+ # mask_img,
157
+ # output_img,
158
+ # prompt_text,
159
+ # negative_prompt_text,
160
+ # ],
161
+ # outputs=[
162
+ # input_img,
163
+ # mask_img,
164
+ # output_img,
165
+ # prompt_text,
166
+ # negative_prompt_text,
167
+ # ],
168
+ # )
169
 
170
  if __name__ == "__main__":
171
  demo.queue()