Spaces:
Runtime error
Runtime error
File size: 4,489 Bytes
395679e 5930504 395679e 9e83947 395679e 261c2c3 395679e 261c2c3 395679e 519a004 395679e 519a004 395679e 746cb43 395679e 746cb43 395679e 746cb43 395679e 746cb43 395679e 746cb43 395679e 261c2c3 5930504 395679e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 |
import numpy as np
import pandas as pd
from PIL import Image
from collections import defaultdict
import streamlit as st
from streamlit_drawable_canvas import st_canvas
import matplotlib as mpl
from model import device, segment_image, inpaint
# define utils and helpers
def closest_number(n, m=8):
    """Round ``n`` to a multiple of ``m`` by truncating toward zero.

    Despite the name this does not pick the *nearest* multiple: for positive
    ``n`` it returns the largest multiple of ``m`` that is <= ``n``
    (e.g. ``closest_number(15) == 8``), and for negative ``n`` it rounds
    toward zero (``closest_number(-15) == -8``).

    Args:
        n: Value to round (int or float).
        m: Step the result must be a multiple of (default 8).

    Returns:
        int: ``m`` times the integer part of ``n / m``.
    """
    whole_steps = int(n / m)  # int() truncates toward zero
    return m * whole_steps
def _canvas_boxes(objects, height, width):
    """Convert fabric.js canvas objects into integer ``(left, top, right, bottom)`` boxes.

    Boxes are rounded to whole pixels (numpy slicing rejects float indices) and
    clipped to ``[0, width] x [0, height]``; boxes with no area left after
    clipping are dropped.
    """
    frame = pd.json_normalize(objects)
    if frame.empty:
        return []
    # fabric.js stores the unscaled width/height; objects resized in "transform"
    # mode carry a scaleX/scaleY factor instead of a new width/height.
    defaults = {"left": 0.0, "top": 0.0, "width": 0.0, "height": 0.0,
                "scaleX": 1.0, "scaleY": 1.0}
    frame = frame.reindex(columns=list(defaults)).fillna(defaults).astype(float)

    left = frame["left"].to_numpy()
    top = frame["top"].to_numpy()
    right = left + frame["width"].to_numpy() * frame["scaleX"].to_numpy()
    bottom = top + frame["height"].to_numpy() * frame["scaleY"].to_numpy()

    # Clip so a box hanging off the canvas cannot produce a negative (wrapping) slice.
    xs = np.clip(np.rint(np.stack([left, right])), 0, width).astype(int)
    ys = np.clip(np.rint(np.stack([top, bottom])), 0, height).astype(int)
    return [
        (int(x0), int(y0), int(x1), int(y1))
        for x0, x1, y0, y1 in zip(xs[0], xs[1], ys[0], ys[1])
        if x1 > x0 and y1 > y0
    ]


def get_mask_from_rectangles(image, mask, height, width, drawing_mode='rect'):
    """Let the user draw boxes over ``image`` and burn them into ``mask``.

    Renders a drawable canvas with ``image`` as its background. For each drawn
    object, the covered region of ``image`` is previewed and the matching
    pixels of ``mask`` are set to 255.

    Args:
        image: PIL image used as the canvas background and cropped for previews.
        mask: 2-D array indexed as ``mask[row, col]``; modified in place
            (the caller in this file passes a ``(height, width)`` uint8 array).
        height: Canvas height in pixels.
        width: Canvas width in pixels.
        drawing_mode: Drawing mode forwarded to ``st_canvas``.

    Returns:
        The same ``mask`` object, with every drawn box set to 255.
    """
    canvas_result = st_canvas(
        fill_color="rgba(255, 165, 0, 0.3)",
        stroke_width=2,
        stroke_color="#000000",
        background_image=image,
        update_streamlit=True,
        height=height,
        width=width,
        drawing_mode=drawing_mode,
        point_display_radius=5,
        key="canvas",
    )

    json_data = canvas_result.json_data
    if json_data is None:
        return mask

    boxes = _canvas_boxes(json_data.get("objects") or [], height, width)
    if not boxes:
        return mask

    for left, top, right, bottom in boxes:
        st.image(image.crop((left, top, right, bottom)))
        mask[top:bottom, left:right] = 255

    st.header("Mask Created!")
    st.image(mask)
    return mask
def get_mask(image, edit_method, height, width):
    """Build an inpainting mask for ``image`` with the chosen edit method.

    Args:
        image: PIL image being edited.
        edit_method: ``"AutoSegment Area"`` to pick segments predicted by
            ``segment_image``, or ``"Draw Custom Area"`` to draw boxes on a canvas.
            Any other value returns an all-zero mask.
        height: Mask height in pixels.
        width: Mask width in pixels.

    Returns:
        A mask where 255 marks pixels to inpaint and 0 pixels to keep.
        NOTE(review): the type depends on the path taken. It is a ``PIL.Image``
        (mode "L") when segments were selected, otherwise a ``(height, width)``
        uint8 numpy array. Both are handed to ``inpaint``; confirm it accepts
        either.
    """
    mask = np.zeros((height, width), dtype=np.uint8)
    if edit_method == "AutoSegment Area":
        seg_prediction, segment_labels = segment_image(image)
        seg = seg_prediction['segmentation'].cpu().numpy()

        # Colormaps index their LUT directly with integer input (presumably the
        # segment ids here are ints), so one entry per id 0..max is needed. With
        # only ``max`` entries the top id fell into the "over" colour, which
        # duplicates the previous entry. A single-segment map (max == 0) would
        # request an empty colormap.
        n_colors = max(int(np.max(seg)) + 1, 1)
        viridis = mpl.colormaps.get_cmap('viridis').resampled(n_colors)
        seg_image = Image.fromarray(np.uint8(viridis(seg) * 255))
        st.image(seg_image)

        # Options are (segment id, label) pairs; the id is the value stored in ``seg``.
        seg_selections = st.multiselect("Choose segments", list(segment_labels.items()))
        if seg_selections:
            tgts = [seg_id for seg_id, _ in seg_selections]
            mask = Image.fromarray(np.isin(seg, tgts).astype(np.uint8) * 255)
            st.header("Mask Created!")
            st.image(mask)
    elif edit_method == "Draw Custom Area":
        mask = get_mask_from_rectangles(image, mask, height, width)
    return mask
if __name__ == '__main__':
    st.title("Stable Edit")
    st.title("Edit your photos with Stable Diffusion!")
    st.write(f"Device found: {device}")

    # Integer factor the uploaded image is divided by before editing.
    sf = st.text_input("Please enter resizing scale factor to downsize image (default=2)", value="2")
    try:
        sf = int(sf)
    except ValueError:
        sf = 0  # invalid input; reported by the range check below
    if sf < 1:
        # Non-integer, zero or negative factors would break the division/resize below.
        st.warning("Error with input scale factor, setting to default value of 2, please re-enter above to change it")
        sf = 2

    # upload image
    filename = st.file_uploader("upload an image")
    if filename:
        # Normalise palette/RGBA/greyscale uploads to 3-channel RGB for the models.
        image = Image.open(filename).convert("RGB")
        width, height = image.size
        # Round each side down to a multiple of 8 (closest_number's default step),
        # but never below 8 px so tiny uploads do not resize to a 0-size image.
        width = max(closest_number(width / sf), 8)
        height = max(closest_number(height / sf), 8)
        image = image.resize((width, height))
        st.image(image)

        # Select an editing method
        edit_method = st.selectbox("Select Edit Method", ("AutoSegment Area", "Draw Custom Area"))
        if edit_method:
            mask = get_mask(image, edit_method, height, width)

            # get inpainted images
            prompt = st.text_input("Please enter prompt for image inpainting", value="")
            if prompt:
                st.write("Inpainting Images, patience is a virtue and this will take a while to run on a CPU :)")
                images = inpaint(image, mask, width, height, prompt=prompt, seed=0, guidance_scale=17.5, num_samples=3)

                # display all images
                st.write("Original Image")
                st.image(image)
                for i, img in enumerate(images, 1):
                    st.write(f"result: {i}")
                    st.image(img)
|