NimaBoscarino commited on
Commit
4a51a01
1 Parent(s): 764fb00

Create Streamlit demo

Browse files
README.md CHANGED
@@ -1,45 +1,20 @@
1
  ---
2
- title: Aot Gan Inpainting
3
  emoji: 🦀
4
  colorFrom: red
5
  colorTo: pink
6
  sdk: streamlit
7
  app_file: app.py
8
- pinned: false
 
9
  ---
10
 
11
- # Configuration
 
 
12
 
13
- `title`: _string_
14
- Display title for the Space
15
-
16
- `emoji`: _string_
17
- Space emoji (emoji-only character allowed)
18
-
19
- `colorFrom`: _string_
20
- Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
21
-
22
- `colorTo`: _string_
23
- Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
24
-
25
- `sdk`: _string_
26
- Can be either `gradio`, `streamlit`, or `static`
27
-
28
- `sdk_version` : _string_
29
- Only applicable for `streamlit` SDK.
30
- See [doc](https://hf.co/docs/hub/spaces) for more info on supported versions.
31
-
32
- `app_file`: _string_
33
- Path to your main application file (which contains either `gradio` or `streamlit` Python code, or `static` html code).
34
- Path is relative to the root of the repository.
35
-
36
- `models`: _List[string]_
37
- HF model IDs (like "gpt2" or "deepset/roberta-base-squad2") used in the Space.
38
- Will be parsed automatically from your code if not specified here.
39
-
40
- `datasets`: _List[string]_
41
- HF dataset IDs (like "common_voice" or "oscar-corpus/OSCAR-2109") used in the Space.
42
- Will be parsed automatically from your code if not specified here.
43
-
44
- `pinned`: _boolean_
45
- Whether the Space stays on top of your list.
 
1
  ---
2
+ title: AOT-GAN Inpainting
3
  emoji: 🦀
4
  colorFrom: red
5
  colorTo: pink
6
  sdk: streamlit
7
  app_file: app.py
8
+ pinned: true
9
+ sdk_version: 1.0.0
10
  ---
11
 
12
+ # AOT-GAN Inpainting
13
+ This space demonstrates the [AOT-GAN for High-Resolution Image Inpainting](https://github.com/researchmm/AOT-GAN-for-Inpainting) developed
14
+ by Yanhong Zeng, Jianlong Fu, Hongyang Chao, and Baining Guo. The GAN allows you to fill in large missing regions in high-resolution images.
15
 
16
+ ## Image Credit
17
+ - `man.jpg` is sourced from https://github.com/researchmm/AOT-GAN-for-Inpainting
18
+ - Photo by Ike louie Natividad from Pexels https://www.pexels.com/photo/woman-smiling-2709388/
19
+ - Photo by Christina Morillo from Pexels https://www.pexels.com/photo/woman-smiling-at-the-camera-1181686/
20
+ - Photo by Italo Melo from Pexels https://www.pexels.com/photo/man-wearing-blue-crew-neck-t-shirt-2379005/
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ import streamlit as st
3
+ from streamlit_drawable_canvas import st_canvas
4
+ from torchvision.transforms import ToTensor
5
+ import torch
6
+ import numpy as np
7
+ import cv2
8
+ import aotgan.model.aotgan as net
9
+
10
+ @st.cache
11
+ def load_model(model_name):
12
+ model = net.InpaintGenerator.from_pretrained(model_name)
13
+ return model
14
+
15
+ def postprocess(image):
16
+ image = torch.clamp(image, -1., 1.)
17
+ image = (image + 1) / 2.0 * 255.0
18
+ image = image.permute(1, 2, 0)
19
+ image = image.cpu().numpy().astype(np.uint8)
20
+ return image
21
+
22
+ def infer(img, mask):
23
+ with torch.no_grad():
24
+ img_cv = cv2.resize(np.array(img), (512, 512)) # Fixing everything to 512 x 512 for this demo.
25
+ img_tensor = (ToTensor()(img_cv) * 2.0 - 1.0).unsqueeze(0)
26
+ mask_tensor = (ToTensor()(mask.astype(np.uint8))).unsqueeze(0)
27
+ masked_tensor = (img_tensor * (1 - mask_tensor).float()) + mask_tensor
28
+ pred_tensor = model(masked_tensor, mask_tensor)
29
+ comp_tensor = (pred_tensor * mask_tensor + img_tensor * (1 - mask_tensor))
30
+ comp_np = postprocess(comp_tensor[0])
31
+
32
+ return comp_np
33
+
34
+ stroke_width = 8
35
+ stroke_color = "#FFF"
36
+ bg_color = "#000"
37
+ bg_image = st.sidebar.file_uploader("Image:", type=["png", "jpg", "jpeg"])
38
+ sample_bg_image = st.sidebar.radio('Sample Images', [
39
+ "man.png",
40
+ "pexels-ike-louie-natividad-2709388.jpg",
41
+ "pexels-christina-morillo-1181686.jpg",
42
+ "pexels-italo-melo-2379005.jpg",
43
+ "rainbow.jpeg",
44
+ "kitty.jpg",
45
+ "kitty_on_chair.jpeg",
46
+ ])
47
+ drawing_mode = st.sidebar.selectbox(
48
+ "Drawing tool:", ("freedraw", "rect", "circle")
49
+ )
50
+
51
+ model_name = st.sidebar.selectbox(
52
+ "Select model:", ("NimaBoscarino/aot-gan-celebahq", "NimaBoscarino/aot-gan-places2")
53
+ )
54
+ model = load_model(model_name)
55
+
56
+ bg_image = Image.open(bg_image) if bg_image else Image.open(f"./pictures/{sample_bg_image}")
57
+
58
+ st.subheader("Draw on the image to erase features. The inpainted result will be generated and displayed below.")
59
+ canvas_result = st_canvas(
60
+ fill_color="rgb(255, 255, 255)",
61
+ stroke_width=stroke_width,
62
+ stroke_color=stroke_color,
63
+ background_color=bg_color,
64
+ background_image=bg_image,
65
+ update_streamlit=True,
66
+ height=512,
67
+ width=512,
68
+ drawing_mode=drawing_mode,
69
+ key="canvas",
70
+ )
71
+
72
+ if canvas_result.image_data is not None and bg_image and len(canvas_result.json_data["objects"]) > 0:
73
+ result = infer(bg_image, canvas_result.image_data[:, :, 3])
74
+ st.image(result)
pictures/kitty.jpg ADDED
pictures/kitty_on_chair.jpeg ADDED
pictures/man.png ADDED
pictures/pexels-christina-morillo-1181686.jpg ADDED
pictures/pexels-ike-louie-natividad-2709388.jpg ADDED
pictures/pexels-italo-melo-2379005.jpg ADDED
pictures/rainbow.jpeg ADDED
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ git+https://github.com/NimaBoscarino/AOT-GAN-for-Inpainting.git
2
+ streamlit_drawable_canvas
3
+ opencv-python==4.5.1.48
4
+ torch==1.8.1
5
+ torchvision==0.9.1
6
+ pillow==8.1.2
7
+ transformers==4.15.0