Spaces:
Runtime error
Runtime error
File size: 4,807 Bytes
395679e 5930504 395679e 9e83947 53e96bd 746cb43 395679e 746cb43 395679e 746cb43 395679e 746cb43 395679e 746cb43 395679e 746cb43 395679e 0ae24ac 746cb43 395679e 5930504 0ae24ac 5930504 395679e 5930504 395679e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 |
import numpy as np
import pandas as pd
from PIL import Image
from collections import defaultdict
import streamlit as st
from streamlit_drawable_canvas import st_canvas
import torch
from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation
from diffusers import StableDiffusionInpaintPipeline
import matplotlib as mpl
from model import segment_image, inpaint
# Utilities and helpers.
# Torch device chosen once at import time: a CUDA GPU when one is visible, otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'
def closest_number(n, m=8):
    """Round ``n`` down to a multiple of ``m``.

    Despite the name, this does not round to the *nearest* multiple. The
    quotient ``n / m`` is truncated toward zero. For non-negative ``n`` the
    result is therefore the largest multiple of ``m`` that is <= ``n``
    (``closest_number(15) == 8``, not 16). Rounding down means a dimension is
    never enlarged. Any ``0 <= n < m`` yields 0.

    Args:
        n: Value to round (int or float).
        m: Non-zero divisor; defaults to 8.

    Returns:
        int: ``n`` truncated to a multiple of ``m``.
    """
    return int(n / m) * m
def get_mask_from_rectangles(image, mask, height, width, drawing_mode='rect'):
    """Let the user draw shapes over ``image`` and mark them in ``mask``.

    Shows a drawable canvas with ``image`` as its background. For every drawn
    object, the box given by its ``left``/``top``/``width``/``height`` fields in
    the canvas JSON is set to 255 in ``mask``, and a crop of that box is shown.

    Args:
        image: PIL image used as the canvas background and for the crops.
        mask: 2-D uint8 NumPy array; modified in place.
        height: Canvas height in pixels.
        width: Canvas width in pixels.
        drawing_mode: Drawing mode passed through to ``st_canvas``.

    Returns:
        ``mask`` (the same array), with every drawn box set to 255.
    """
    # Create a canvas component
    canvas_result = st_canvas(
        fill_color="rgba(255, 165, 0, 0.3)",
        stroke_width=2,
        stroke_color="#000000",
        background_image=image,
        update_streamlit=True,
        height=height,
        width=width,
        drawing_mode=drawing_mode,
        point_display_radius=5,
        key="canvas",
    )
    if canvas_result.json_data is None:
        return mask
    objects = pd.json_normalize(canvas_result.json_data["objects"])
    if len(objects) == 0:
        return mask

    # Canvas coordinates can be floats and can lie outside the canvas when a
    # shape extends past an edge. NumPy slices need integer indices, and a
    # negative start index would wrap around to the far side of the array.
    # So round every box and clip it to the mask bounds before using it.
    mask_h, mask_w = mask.shape[:2]
    lefts = objects["left"].to_numpy(dtype=float)
    tops = objects["top"].to_numpy(dtype=float)
    rights = lefts + objects["width"].to_numpy(dtype=float)
    bottoms = tops + objects["height"].to_numpy(dtype=float)
    lefts = np.clip(np.rint(lefts), 0, mask_w).astype(int)
    rights = np.clip(np.rint(rights), 0, mask_w).astype(int)
    tops = np.clip(np.rint(tops), 0, mask_h).astype(int)
    bottoms = np.clip(np.rint(bottoms), 0, mask_h).astype(int)

    # add selections to mask
    for left, top, right, bottom in zip(lefts, tops, rights, bottoms):
        if right <= left or bottom <= top:
            continue  # no part of this box lies inside the canvas
        st.image(image.crop((int(left), int(top), int(right), int(bottom))))
        mask[top:bottom, left:right] = 255
    st.header("Mask Created!")
    st.image(mask)
    return mask
def get_mask(image, edit_method, height, width):
    """Build an inpainting mask for ``image`` with the chosen edit method.

    Args:
        image: PIL image, already resized to ``width`` x ``height``.
        edit_method: ``"AutoSegment Area"`` to pick segments predicted by
            ``segment_image``, or ``"Draw Custom Area"`` to draw boxes on a
            canvas. Any other value leaves the mask empty.
        height: Mask height in pixels.
        width: Mask width in pixels.

    Returns:
        The mask, with selected pixels set to 255. It is a ``(height, width)``
        uint8 NumPy array, except when segments are chosen in AutoSegment mode.
        In that case it is a PIL image built from such an array.
    """
    mask = np.zeros((height, width), dtype=np.uint8)
    if edit_method == "AutoSegment Area":
        # get displayable segmented image
        seg_prediction, segment_labels = segment_image(image)
        seg = seg_prediction['segmentation'].cpu().numpy()
        # Integer inputs index the colormap's lookup table directly, so it needs
        # one entry per id in 0..max. With only `max` entries, the largest id
        # lands in the "over" slot and shares a colour with its neighbour, and a
        # maximum of 0 would request an empty colormap.
        n_colors = max(int(np.max(seg)) + 1, 1)
        viridis = mpl.colormaps.get_cmap('viridis').resampled(n_colors)
        seg_image = Image.fromarray(np.uint8(viridis(seg) * 255))
        st.image(seg_image)

        # prompt user to select valid labels to edit; options are (id, label) pairs
        seg_selections = st.multiselect("Choose segments", list(segment_labels.items()))
        if seg_selections:
            tgts = [seg_id for seg_id, _ in seg_selections]
            mask = Image.fromarray(np.isin(seg, tgts).astype(np.uint8) * 255)
            st.header("Mask Created!")
            st.image(mask)
    elif edit_method == "Draw Custom Area":
        mask = get_mask_from_rectangles(image, mask, height, width)
    return mask
def main():
    """Run the Stable Edit app: upload, downscale, select an area, then inpaint."""
    st.title("Stable Edit - Edit your photos with Stable Diffusion!")

    # upload image
    filename = st.file_uploader("upload an image")
    sf = st.text_input("Please enter resizing scale factor to downsize image (default = 2)", value="2")
    try:
        sf = int(sf)
    except ValueError:
        sf = None
    # Zero or negative factors would divide by zero or produce negative sizes.
    if sf is None or sf < 1:
        st.write("Error with input scale factor, setting to default value of 2, please re-enter above to change it")
        sf = 2
    if not filename:
        return

    # Uploads may be RGBA, palette or grayscale images. Convert to 3-channel RGB
    # for the models (assumed requirement of segment_image / inpaint; confirm).
    image = Image.open(filename).convert("RGB")
    # Downscale, then round each side down to a multiple of 8, the size
    # granularity the Stable Diffusion pipeline expects.
    width, height = image.size
    width, height = closest_number(width / sf), closest_number(height / sf)
    if width == 0 or height == 0:
        st.write("Scale factor is too large for this image, please enter a smaller one above")
        return
    image = image.resize((width, height))
    st.image(image)

    # Select an editing method
    edit_method = st.selectbox("Select Edit Method", ("AutoSegment Area", "Draw Custom Area"))
    mask = get_mask(image, edit_method, height, width)

    # get inpainting parameters
    prompt = st.text_input("Please enter prompt for image inpainting", value="")
    seed = st.text_input("(Optional) enter seed to change inpainting result (default=0)", value="0")
    try:
        seed = int(seed)
    except ValueError:
        st.write("Invalid seed! Defaulting to 0, please re-enter above to change it")
        seed = 0

    # Streamlit re-executes this script on every widget interaction. Without
    # this guard, a 3-sample inpainting run would start as soon as an image is
    # uploaded, with an empty mask and an empty prompt.
    if not prompt or not np.asarray(mask).any():
        st.write("Select an area to edit and enter a prompt above to start inpainting")
        return

    st.write("Inpainting Images, patience is a virtue :)")
    images = inpaint(image, mask, width, height, prompt=prompt, seed=seed, guidance_scale=17.5, num_samples=3)

    # display all images
    st.write("Original Image")
    st.image(image)
    for i, img in enumerate(images, 1):
        st.write(f"result: {i}")
        st.image(img)


if __name__ == '__main__':
    main()
|