Samseisun committed on
Commit e061e2d · verified · 1 Parent(s): e58f2d5

Upload 5 files

Files changed (5)
  1. .gitattributes +35 -35
  2. README.md +13 -12
  3. app.py +138 -0
  4. model.py +86 -0
  5. requirements.txt +9 -0
.gitattributes CHANGED
@@ -1,35 +1,35 @@
- *.7z filter=lfs diff=lfs merge=lfs -text
- *.arrow filter=lfs diff=lfs merge=lfs -text
- *.bin filter=lfs diff=lfs merge=lfs -text
- *.bz2 filter=lfs diff=lfs merge=lfs -text
- *.ckpt filter=lfs diff=lfs merge=lfs -text
- *.ftz filter=lfs diff=lfs merge=lfs -text
- *.gz filter=lfs diff=lfs merge=lfs -text
- *.h5 filter=lfs diff=lfs merge=lfs -text
- *.joblib filter=lfs diff=lfs merge=lfs -text
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
- *.model filter=lfs diff=lfs merge=lfs -text
- *.msgpack filter=lfs diff=lfs merge=lfs -text
- *.npy filter=lfs diff=lfs merge=lfs -text
- *.npz filter=lfs diff=lfs merge=lfs -text
- *.onnx filter=lfs diff=lfs merge=lfs -text
- *.ot filter=lfs diff=lfs merge=lfs -text
- *.parquet filter=lfs diff=lfs merge=lfs -text
- *.pb filter=lfs diff=lfs merge=lfs -text
- *.pickle filter=lfs diff=lfs merge=lfs -text
- *.pkl filter=lfs diff=lfs merge=lfs -text
- *.pt filter=lfs diff=lfs merge=lfs -text
- *.pth filter=lfs diff=lfs merge=lfs -text
- *.rar filter=lfs diff=lfs merge=lfs -text
- *.safetensors filter=lfs diff=lfs merge=lfs -text
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
- *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
- *.tflite filter=lfs diff=lfs merge=lfs -text
- *.tgz filter=lfs diff=lfs merge=lfs -text
- *.wasm filter=lfs diff=lfs merge=lfs -text
- *.xz filter=lfs diff=lfs merge=lfs -text
- *.zip filter=lfs diff=lfs merge=lfs -text
- *.zst filter=lfs diff=lfs merge=lfs -text
- *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,12 +1,13 @@
- ---
- title: Space1
- emoji: 🐨
- colorFrom: purple
- colorTo: indigo
- sdk: streamlit
- sdk_version: 1.39.0
- app_file: app.py
- pinned: false
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ ---
+ title: Stable Edit
+ emoji: 📊
+ colorFrom: indigo
+ colorTo: indigo
+ sdk: streamlit
+ sdk_version: 1.21.0
+ app_file: app.py
+ pinned: false
+ license: cc
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,138 @@
+ import numpy as np
+ import pandas as pd
+ from PIL import Image
+
+ import streamlit as st
+ from streamlit_drawable_canvas import st_canvas
+
+ import matplotlib as mpl
+
+ from model import device, segment_image, inpaint
+
+
+ # define utils and helpers
+ def closest_number(n, m=8):
+     """Obtains the closest number to n that is divisible by m."""
+     return int(n / m) * m
+
+
+ def get_mask_from_rectangles(image, mask, height, width, drawing_mode='rect'):
+     # Create a canvas component the user can draw rectangles on
+     canvas_result = st_canvas(
+         fill_color="rgba(255, 165, 0, 0.3)",
+         stroke_width=2,
+         stroke_color="#000000",
+         background_image=image,
+         update_streamlit=True,
+         height=height,
+         width=width,
+         drawing_mode=drawing_mode,
+         point_display_radius=5,
+         key="canvas",
+     )
+
+     # get selections from canvas
+     if canvas_result.json_data is not None:
+         objects = pd.json_normalize(canvas_result.json_data["objects"])
+         for col in objects.select_dtypes(include=["object"]).columns:
+             objects[col] = objects[col].astype("str")
+
+         if len(objects) > 0:
+             # cast to int so the coordinates can index into the mask array
+             left_coords = objects.left.to_numpy().astype(int)
+             top_coords = objects.top.to_numpy().astype(int)
+             right_coords = left_coords + objects.width.to_numpy().astype(int)
+             bottom_coords = top_coords + objects.height.to_numpy().astype(int)
+
+             # add selections to mask
+             for (left, top, right, bottom) in zip(left_coords, top_coords, right_coords, bottom_coords):
+                 cropped = image.crop((left, top, right, bottom))
+                 st.image(cropped)
+                 mask[top:bottom, left:right] = 255
+
+             st.header("Mask Created!")
+             st.image(mask)
+
+     return mask
+
+
+ def get_mask(image, edit_method, height, width):
+     mask = np.zeros((height, width), dtype=np.uint8)
+
+     if edit_method == "AutoSegment Area":
+         # get displayable segmented image
+         seg_prediction, segment_labels = segment_image(image)
+         seg = seg_prediction['segmentation'].cpu().numpy()
+         # +1 so the colormap covers every segment id (ids start at 0)
+         viridis = mpl.colormaps.get_cmap('viridis').resampled(int(np.max(seg)) + 1)
+         seg_image = Image.fromarray(np.uint8(viridis(seg) * 255))
+
+         st.image(seg_image)
+
+         # prompt user to select valid labels to edit
+         seg_selections = st.multiselect("Choose segments", list(segment_labels.items()))
+         if seg_selections:
+             tgts = [s[0] for s in seg_selections]
+             mask = Image.fromarray(np.array([(seg == t) for t in tgts]).sum(axis=0).astype(np.uint8) * 255)
+             st.header("Mask Created!")
+             st.image(mask)
+
+     elif edit_method == "Draw Custom Area":
+         mask = get_mask_from_rectangles(image, mask, height, width)
+
+     return mask
+
+
+ if __name__ == '__main__':
+     st.title("Stable Edit")
+     st.subheader("Edit your photos with Stable Diffusion!")
+
+     st.write(f"Device found: {device}")
+
+     sf = st.text_input("Please enter resizing scale factor to downsize image (default=2)", value="2")
+     try:
+         sf = int(sf)
+     except ValueError:
+         st.write("Error with input scale factor, setting to default value of 2; please re-enter above to change it")
+         sf = 2
+
+     # upload image
+     filename = st.file_uploader("Upload an image")
+
+     if filename:
+         image = Image.open(filename)
+
+         # downsize and snap dimensions to multiples of 8 for Stable Diffusion
+         width, height = image.size
+         width, height = closest_number(width / sf), closest_number(height / sf)
+         image = image.resize((width, height))
+
+         st.image(image)
+
+         # Select an editing method
+         edit_method = st.selectbox("Select Edit Method", ("AutoSegment Area", "Draw Custom Area"))
+
+         if edit_method:
+             mask = get_mask(image, edit_method, height, width)
+
+             # get inpainted images
+             prompt = st.text_input("Please enter prompt for image inpainting", value="")
+
+             if prompt:
+                 st.write("Inpainting Images, patience is a virtue and this will take a while to run on a CPU :)")
+                 images = inpaint(image, mask, width, height, prompt=prompt, seed=0, guidance_scale=17.5, num_samples=3)
+
+                 # display all images
+                 st.write("Original Image")
+                 st.image(image)
+                 for i, img in enumerate(images, 1):
+                     st.write(f"result: {i}")
+                     st.image(img)
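
A quick aside on closest_number (a sketch, not part of the committed files): Stable Diffusion's VAE downsamples by a factor of 8, so the inpainting pipeline expects width and height to be multiples of 8; the helper floors arbitrary upload sizes down to the nearest valid dimension. The values below are illustrative:

    # closest_number floors n to the nearest multiple of m (default 8)
    assert closest_number(513) == 512
    assert closest_number(767) == 760    # floors, never rounds up
    assert closest_number(1024) == 1024  # already valid, unchanged
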
model.py ADDED
@@ -0,0 +1,86 @@
+ import torch
+ from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation
+ from diffusers import StableDiffusionInpaintPipeline
+
+
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+
+ # Image segmentation
+ seg_processor = AutoImageProcessor.from_pretrained("facebook/mask2former-swin-base-coco-panoptic")
+ seg_model = Mask2FormerForUniversalSegmentation.from_pretrained("facebook/mask2former-swin-base-coco-panoptic")
+
+ def segment_image(image):
+     inputs = seg_processor(image, return_tensors="pt")
+
+     with torch.no_grad():
+         seg_outputs = seg_model(**inputs)
+
+     # get prediction dict
+     seg_prediction = seg_processor.post_process_panoptic_segmentation(seg_outputs, target_sizes=[image.size[::-1]])[0]
+
+     # map each segment id to its human-readable label
+     segment_labels = {}
+     for segment in seg_prediction['segments_info']:
+         segment_id = segment['id']
+         segment_label_id = segment['label_id']
+         segment_labels[segment_id] = seg_model.config.id2label[segment_label_id]
+
+     return seg_prediction, segment_labels
+
+
+ # Image inpainting pipeline
+ # get Stable Diffusion model for image inpainting
+ if device == 'cuda':
+     pipe = StableDiffusionInpaintPipeline.from_pretrained(
+         "runwayml/stable-diffusion-inpainting",
+         torch_dtype=torch.float16,
+     ).to(device)
+ else:
+     pipe = StableDiffusionInpaintPipeline.from_pretrained(
+         "runwayml/stable-diffusion-inpainting",
+         torch_dtype=torch.bfloat16,
+         device_map="auto",  # used for Hugging Face Spaces (CPU)
+     )
+
+ def inpaint(image, mask, W, H, prompt="", seed=0, guidance_scale=17.5, num_samples=3):
+     """Uses the Stable Diffusion inpainting model to edit the masked region of an image.
+     Inputs:
+         image - input image (PIL or torch tensor)
+         mask - mask for inpainting, same size as image (PIL or torch tensor)
+         W - width of the image
+         H - height of the image
+         prompt - text prompt for inpainting
+         seed - random seed
+         guidance_scale - classifier-free guidance scale
+         num_samples - number of images to generate
+     Outputs:
+         images - list of output images
+     """
+     generator = torch.Generator(device=device).manual_seed(seed)
+     images = pipe(
+         prompt=prompt,
+         image=image,
+         mask_image=mask,  # ensure mask is same type as image
+         height=H,
+         width=W,
+         guidance_scale=guidance_scale,
+         generator=generator,
+         num_images_per_prompt=num_samples,
+     ).images
+
+     return images
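
A quick aside (a sketch, not part of the committed files): model.py can also be driven outside Streamlit. The snippet below assumes a hypothetical local image test.jpg and an illustrative prompt; the resize mirrors what app.py does before calling inpaint, since the pipeline expects dimensions divisible by 8:

    import numpy as np
    from PIL import Image

    from model import segment_image, inpaint  # triggers model downloads on import

    image = Image.open("test.jpg")  # hypothetical input path
    w, h = (image.size[0] // 8) * 8, (image.size[1] // 8) * 8
    image = image.resize((w, h))

    # segment the image and build a binary mask from one segment id
    seg_prediction, segment_labels = segment_image(image)
    seg = seg_prediction["segmentation"].cpu().numpy()
    first_id = next(iter(segment_labels))  # pick any segment to replace
    mask = Image.fromarray(((seg == first_id) * 255).astype(np.uint8))

    results = inpaint(image, mask, w, h, prompt="a red sports car", num_samples=1)
    results[0].save("edited.jpg")  # illustrative output path
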
requirements.txt ADDED
@@ -0,0 +1,9 @@
+ pillow
+ numpy
+ scipy
+ matplotlib
+ streamlit-drawable-canvas
+ accelerate
+ torch==2.0.1
+ transformers==4.30.2
+ diffusers==0.11.1