# stable_edit / app.py
# Author: itberrios — commit 395679e ("updated app"), 4.63 kB
import numpy as np
import pandas as pd
from PIL import Image
from collections import defaultdict
import streamlit as st
from streamlit_drawable_canvas import st_canvas
import torch
from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation
from diffusers import StableDiffusionInpaintPipeline
import matplotlib as mpl
from model import segment_image, inpaint
# define utils and helpers

# Torch device name: 'cuda' when torch reports an available GPU, otherwise 'cpu'.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'
def closest_number(n, m=8):
    """Round ``n`` toward zero to a multiple of ``m``.

    Despite the name, this does not pick the *nearest* multiple. For positive
    ``n`` it returns the largest multiple of ``m`` that does not exceed ``n``
    (e.g. ``closest_number(15) == 8``, not 16). Rounding down means a size
    computed with it never exceeds the requested size.

    Args:
        n: Value to round (int or float).
        m: Divisor; the result is a multiple of it. Defaults to 8.

    Returns:
        ``int(n / m) * m``. Values with ``abs(n) < m`` give 0.
    """
    return int(n / m) * m
def get_mask_from_rectangles(image, mask, height, width, drawing_mode='rect'):
    """Let the user draw boxes over ``image`` and paint them into ``mask``.

    Renders a ``st_canvas`` drawing canvas with ``image`` as its background.
    Every object the canvas reports is treated as an axis-aligned box built
    from its ``left``/``top``/``width``/``height`` fields. Each box is shown
    as a cropped preview and set to 255 in ``mask``.

    Args:
        image: PIL image used as the canvas background and cropped for previews.
        mask: 2-D uint8 array, modified in place. It is expected to be
            ``(height, width)`` so canvas coordinates line up with its indices.
        height: Canvas height in pixels.
        width: Canvas width in pixels.
        drawing_mode: Canvas drawing mode. Defaults to ``'rect'``.

    Returns:
        The same ``mask`` object, with every drawn box set to 255.
    """
    # Create a canvas component
    canvas_result = st_canvas(
        fill_color="rgba(255, 165, 0, 0.3)",
        stroke_width=2,
        stroke_color="#000000",
        background_image=image,
        update_streamlit=True,
        height=height,
        width=width,
        drawing_mode=drawing_mode,
        point_display_radius=5,
        key="canvas",
    )

    # No canvas state reported yet: leave the mask untouched.
    if canvas_result.json_data is None:
        return mask

    objects = pd.json_normalize(canvas_result.json_data["objects"])
    if len(objects) == 0:
        return mask

    mask_height, mask_width = mask.shape[:2]

    def to_pixels(values, upper):
        # Canvas coordinates arrive as floats, but numpy slice bounds must be
        # integers. Clipping also stops a negative bound from wrapping around
        # to the far edge of the mask.
        return np.clip(np.rint(values), 0, upper).astype(int)

    left = objects["left"].to_numpy(dtype=float)
    top = objects["top"].to_numpy(dtype=float)
    left_coords = to_pixels(left, mask_width)
    top_coords = to_pixels(top, mask_height)
    right_coords = to_pixels(left + objects["width"].to_numpy(dtype=float), mask_width)
    bottom_coords = to_pixels(top + objects["height"].to_numpy(dtype=float), mask_height)

    # add selections to mask
    for x0, y0, x1, y1 in zip(left_coords, top_coords, right_coords, bottom_coords):
        if x1 <= x0 or y1 <= y0:
            # Zero-area box (or one lying entirely off the canvas after
            # clipping): nothing to preview or paint.
            continue
        st.image(image.crop((int(x0), int(y0), int(x1), int(y1))))
        mask[y0:y1, x0:x1] = 255

    st.header("Mask Created!")
    st.image(mask)
    return mask
def get_mask(image, edit_method, height, width):
    """Build an inpainting mask for ``image`` with the chosen edit method.

    Args:
        image: PIL image to edit.
        edit_method: Which selection UI to use.
            - ``"AutoSegment Area"``: pick segments returned by ``segment_image``.
            - ``"Draw Custom Area"``: draw boxes via ``get_mask_from_rectangles``.
            - Any other value leaves the mask empty.
        height: Mask height in pixels.
        width: Mask width in pixels.

    Returns:
        The mask. It starts as an all-zero ``uint8`` array of shape
        ``(height, width)``.
        - "AutoSegment Area" with at least one segment selected: replaced by a
          PIL image whose selected pixels are 255.
        - "Draw Custom Area": the array returned by
          ``get_mask_from_rectangles``.
    """
    mask = np.zeros((height, width), dtype=np.uint8)

    if edit_method == "AutoSegment Area":
        # get displayable segmented image
        seg_prediction, segment_labels = segment_image(image)
        # NOTE(review): assumes 'segmentation' is a torch tensor of integer
        # segment ids and segment_labels maps id -> label; confirm against model.py.
        seg = seg_prediction['segmentation'].cpu().numpy()

        # Integer inputs index the colormap LUT directly, so ids 0..max need
        # max + 1 entries. With only `max` entries the largest id lands in the
        # "over" slot and shares the top color, and a single-segment image
        # would request a zero-entry LUT.
        n_colors = max(int(np.max(seg)) + 1, 1)
        viridis = mpl.colormaps.get_cmap('viridis').resampled(n_colors)
        seg_image = Image.fromarray(np.uint8(viridis(seg) * 255))

        # display image
        st.image(seg_image)

        # prompt user to select valid labels to edit; each option is an (id, label) pair
        seg_selections = st.multiselect("Choose segments", list(segment_labels.items()))
        if seg_selections:
            target_ids = [segment_id for segment_id, _ in seg_selections]
            # Each pixel holds exactly one id, so membership gives the union of the
            # selected segments.
            mask = Image.fromarray(np.isin(seg, target_ids).astype(np.uint8) * 255)
            st.header("Mask Created!")
            st.image(mask)

    elif edit_method == "Draw Custom Area":
        mask = get_mask_from_rectangles(image, mask, height, width)

    return mask
if __name__ == '__main__':

    st.title("Stable Edit - Edit your photos with Stable Diffusion!")

    # upload image
    filename = st.file_uploader("upload an image")

    sf_text = st.text_input("Please enter resizing scale factor to downsize image (default = 2)", value="2")
    try:
        sf = int(sf_text)
    except ValueError:
        sf = None
    # A factor below 1 would divide by zero or produce non-positive image sizes.
    if sf is None or sf < 1:
        st.warning("Error with input scale factor, setting to default value of 2, please re-enter above to change it")
        sf = 2

    if filename:
        image = Image.open(filename)
        width, height = image.size
        # closest_number rounds down to a multiple of 8; clamp so that very small
        # images don't collapse to a zero-sized dimension, which resize rejects.
        width = max(closest_number(width / sf), 8)
        height = max(closest_number(height / sf), 8)
        image = image.resize((width, height))
        st.image(image)

        # Select an editing method
        edit_method = st.selectbox("Select Edit Method", ("AutoSegment Area", "Draw Custom Area"))

        if edit_method:
            mask = get_mask(image, edit_method, height, width)

            # get inpainted images
            prompt = st.text_input("Please enter prompt for image inpainting", value="")

            # Streamlit reruns this whole script on every widget interaction
            # (including each canvas edit). Wait for a prompt instead of launching
            # a 3-sample inpaint on every rerun.
            if prompt:
                st.write("Inpainting Images, patience is a virtue :)")
                images = inpaint(image, mask, width, height, prompt=prompt, seed=0, guidance_scale=17.5, num_samples=3)

                # display all images
                st.write("Original Image")
                st.image(image)
                for i, img in enumerate(images, 1):
                    st.write(f"result: {i}")
                    st.image(img)