Update app.py
Browse files
app.py
CHANGED
@@ -1,10 +1,26 @@
|
|
|
|
1 |
import gradio as gr
|
2 |
import numpy as np
|
3 |
import imageio
|
4 |
from PIL import Image
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
|
6 |
source_img = gr.Image(source="upload", type="numpy", tool="sketch", elem_id="source_container");
|
7 |
-
|
|
|
8 |
def resize(height,img):
|
9 |
baseheight = height
|
10 |
img = Image.open(img)
|
@@ -14,12 +30,6 @@ def resize(height,img):
|
|
14 |
return img
|
15 |
|
16 |
def predict(source_img):
|
17 |
-
|
18 |
-
#print(sketch)
|
19 |
-
#print(sketch.mode)
|
20 |
-
#sketch_png = resize(512,source_img)
|
21 |
-
#sketch_png.save('source.png')
|
22 |
-
#print(sketch_png)
|
23 |
imageio.imwrite("data.png", source_img["image"])
|
24 |
imageio.imwrite("data_mask.png", source_img["mask"])
|
25 |
|
@@ -27,8 +37,17 @@ def predict(source_img):
|
|
27 |
src.save("src.png")
|
28 |
mask = resize(512, "data_mask.png")
|
29 |
mask.save("mask.png")
|
30 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
|
32 |
custom_css="style.css"
|
33 |
|
34 |
-
gr.Interface(fn=predict, inputs=source_img, outputs=
|
|
|
1 |
+
from diffusers import StableDiffusionInpaintPipeline
|
2 |
import gradio as gr
|
3 |
import numpy as np
|
4 |
import imageio
|
5 |
from PIL import Image
|
6 |
+
from io import BytesIO
|
7 |
+
import os
|
8 |
+
|
9 |
+
MY_SECRET_TOKEN=os.environ.get('HF_TOKEN_SD')
|
10 |
+
|
11 |
+
|
12 |
+
print("hello sylvain")
|
13 |
+
|
14 |
+
YOUR_TOKEN=MY_SECRET_TOKEN
|
15 |
+
|
16 |
+
device="cpu"
|
17 |
+
|
18 |
+
pipe = StableDiffusionInpaintPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=YOUR_TOKEN)
|
19 |
+
pipe.to(device)
|
20 |
|
21 |
source_img = gr.Image(source="upload", type="numpy", tool="sketch", elem_id="source_container");
|
22 |
+
gallery = gr.Gallery(label="Generated images", show_label=False, elem_id="gallery").style(grid=[2], height="auto")
|
23 |
+
|
24 |
def resize(height,img):
|
25 |
baseheight = height
|
26 |
img = Image.open(img)
|
|
|
30 |
return img
|
31 |
|
32 |
def predict(source_img):
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
imageio.imwrite("data.png", source_img["image"])
|
34 |
imageio.imwrite("data_mask.png", source_img["mask"])
|
35 |
|
|
|
37 |
src.save("src.png")
|
38 |
mask = resize(512, "data_mask.png")
|
39 |
mask.save("mask.png")
|
40 |
+
|
41 |
+
images_list = img_pipe([prompt] * 1, init_image=src, mask_image=mask, strength=0.75)
|
42 |
+
images = []
|
43 |
+
safe_image = Image.open(r"unsafe.png")
|
44 |
+
for i, image in enumerate(images_list["sample"]):
|
45 |
+
if(images_list["nsfw_content_detected"][i]):
|
46 |
+
images.append(safe_image)
|
47 |
+
else:
|
48 |
+
images.append(image)
|
49 |
+
return images
|
50 |
|
51 |
custom_css="style.css"
|
52 |
|
53 |
+
gr.Interface(fn=predict, inputs=source_img, outputs=gallery, css=custom_css).launch(enable_queue=True)
|