# stable_edit / app.py
# Author: itberrios
# Last commit: 261c2c3 ("update")
import numpy as np
import pandas as pd
from PIL import Image
from collections import defaultdict
import streamlit as st
from streamlit_drawable_canvas import st_canvas
import matplotlib as mpl
from model import device, segment_image, inpaint
# define utils and helpers
def closest_number(n, m=8):
    """Return ``n`` truncated to a multiple of ``m``.

    Despite the name, this does NOT round to the nearest multiple: it drops
    the fractional part of ``n / m`` (rounds toward zero), so
    ``closest_number(15) == 8``. For positive ``n`` the result therefore never
    exceeds ``n``. The default ``m=8`` is presumably chosen because the
    diffusion model expects image sides divisible by 8 — TODO confirm.

    Args:
        n: Value to truncate; may be a float (e.g. ``width / sf``).
        m: The result is a multiple of this number (default 8).

    Returns:
        int: ``int(n / m) * m``. Note this is 0 whenever ``abs(n) < m``.
    """
    return int(n / m) * m
def get_mask_from_rectangles(image, mask, height, width, drawing_mode='rect'):
    """Let the user draw regions on ``image`` and burn them into ``mask``.

    Renders a drawable canvas with ``image`` as background. Every object drawn
    on the canvas is treated as an axis-aligned box built from its
    ``left``/``top``/``width``/``height`` fields. Each box is previewed as a
    crop of ``image`` and filled with 255 in ``mask``.

    Args:
        image: PIL image used as the canvas background and for crop previews.
        mask: 2-D uint8 array (rows x cols); modified in place.
        height: Canvas height in pixels.
        width: Canvas width in pixels.
        drawing_mode: Canvas drawing mode (default ``'rect'``).

    Returns:
        The same ``mask`` object, with the drawn boxes set to 255.
    """
    # Create a canvas component
    canvas_result = st_canvas(
        fill_color="rgba(255, 165, 0, 0.3)",
        stroke_width=2,
        stroke_color="#000000",
        background_image=image,
        update_streamlit=True,
        height=height,
        width=width,
        drawing_mode=drawing_mode,
        point_display_radius=5,
        key="canvas",
    )

    if canvas_result.json_data is None:
        return mask

    objects = pd.json_normalize(canvas_result.json_data["objects"])
    if len(objects) == 0:
        return mask

    # Canvas coordinates may be fractional, and shapes may extend past the
    # canvas edge. NumPy slicing rejects float indices and treats negative
    # starts as offsets from the end, so round to ints and clip to the mask.
    mask_h, mask_w = mask.shape[:2]
    left_f = objects["left"].to_numpy(dtype=float)
    top_f = objects["top"].to_numpy(dtype=float)
    boxes = np.rint(np.stack([
        left_f,
        top_f,
        left_f + objects["width"].to_numpy(dtype=float),
        top_f + objects["height"].to_numpy(dtype=float),
    ], axis=1)).astype(int)
    boxes[:, [0, 2]] = np.clip(boxes[:, [0, 2]], 0, mask_w)
    boxes[:, [1, 3]] = np.clip(boxes[:, [1, 3]], 0, mask_h)

    # add selections to mask
    for left, top, right, bottom in boxes.tolist():
        if right <= left or bottom <= top:
            continue  # box lies entirely outside the canvas
        st.image(image.crop((left, top, right, bottom)))
        mask[top:bottom, left:right] = 255

    st.header("Mask Created!")
    st.image(mask)
    return mask
def get_mask(image, edit_method, height, width):
    """Build an inpainting mask for ``image`` using the chosen edit method.

    Args:
        image: PIL image to edit.
        edit_method: ``"AutoSegment Area"`` to pick segments returned by
            ``segment_image``, or ``"Draw Custom Area"`` to draw boxes.
        height: Rows of the default (empty) mask.
        width: Columns of the default (empty) mask.

    Returns:
        A (height, width) uint8 array of zeros when nothing is selected or
        ``edit_method`` is unrecognised. ``"Draw Custom Area"`` returns that
        array with the drawn boxes set to 255. ``"AutoSegment Area"`` with at
        least one selected segment returns a PIL image instead (255 inside
        the selected segments, 0 elsewhere).
    """
    mask = np.zeros((height, width), dtype=np.uint8)

    if edit_method == "AutoSegment Area":
        # get displayable segmented image
        seg_prediction, segment_labels = segment_image(image)
        seg = seg_prediction['segmentation'].cpu().numpy()

        # Integer inputs index the colormap LUT directly, and values >= N get
        # the "over" colour (the last LUT entry by default). With N = max(seg)
        # the highest id was drawn in the same colour as max(seg) - 1.
        # N = max(seg) + 1 gives each id in 0..max its own colour, and the
        # floor of 1 avoids an empty LUT.
        n_colors = max(int(np.max(seg)) + 1, 1)
        viridis = mpl.colormaps.get_cmap('viridis').resampled(n_colors)
        seg_image = Image.fromarray(np.uint8(viridis(seg) * 255))
        st.image(seg_image)

        # prompt user to select valid labels to edit; options are
        # (segment id, label) pairs and only the id is used below
        seg_selections = st.multiselect("Choose segments", list(segment_labels.items()))
        if seg_selections:
            tgts = [sel[0] for sel in seg_selections]
            # A single membership test instead of summing per-segment masks:
            # the result is always 0/1, so "* 255" cannot wrap around uint8.
            mask = Image.fromarray(np.isin(seg, tgts).astype(np.uint8) * 255)
            st.header("Mask Created!")
            st.image(mask)
    elif edit_method == "Draw Custom Area":
        mask = get_mask_from_rectangles(image, mask, height, width)
    return mask
if __name__ == '__main__':
    st.title("Stable Edit")
    st.title("Edit your photos with Stable Diffusion!")
    st.write(f"Device found: {device}")

    # Integer factor by which both image sides are divided before editing.
    sf_text = st.text_input("Please enter resizing scale factor to downsize image (default=2)", value="2")
    try:
        sf = int(sf_text)
    except ValueError:
        sf = None
    # Also reject 0 / negatives, which would divide by zero or give a
    # negative image size below.
    if sf is None or sf < 1:
        st.write("Error with input scale factor, setting to default value of 2, please re-enter above to change it")
        sf = 2

    # upload image
    filename = st.file_uploader("upload an image")
    if filename:
        image = Image.open(filename)
        # Downscale, then truncate each side to a multiple of 8.
        width, height = image.size
        width, height = closest_number(width / sf), closest_number(height / sf)
        if width == 0 or height == 0:
            # Image.resize cannot produce a zero-sized image.
            st.write(
                f"Image is too small for a scale factor of {sf} "
                f"(resized to {width}x{height}); please lower the scale factor."
            )
            st.stop()
        image = image.resize((width, height))
        st.image(image)

        # Select an editing method
        edit_method = st.selectbox("Select Edit Method", ("AutoSegment Area", "Draw Custom Area"))
        if edit_method:
            mask = get_mask(image, edit_method, height, width)

            # get inpainted images
            prompt = st.text_input("Please enter prompt for image inpainting", value="")
            if prompt:
                st.write("Inpainting Images, patience is a virtue and this will take a while to run on a CPU :)")
                images = inpaint(image, mask, width, height, prompt=prompt, seed=0, guidance_scale=17.5, num_samples=3)

                # display all images
                st.write("Original Image")
                st.image(image)
                for i, img in enumerate(images, 1):
                    st.write(f"result: {i}")
                    st.image(img)