from collections import defaultdict

import numpy as np
import pandas as pd
from PIL import Image
import streamlit as st
from streamlit_drawable_canvas import st_canvas
import matplotlib as mpl

from model import device, segment_image, inpaint


# define utils and helpers
def closest_number(n, m=8):
    """Return the multiple of ``m`` obtained by truncating ``n / m`` toward zero.

    Despite the name, this rounds *down* for non-negative ``n`` (it does not
    pick the nearest multiple), so the result never exceeds ``n``.

    Args:
        n: Value to snap; may be a float (e.g. a scaled image dimension).
        m: The result is a multiple of this value (default 8).

    Returns:
        ``int(n / m) * m`` as an ``int``. Can be 0 when ``0 <= n < m``.
    """
    return int(n / m) * m


def _canvas_boxes(objects, height, width):
    """Convert drawn canvas objects into integer pixel boxes clipped to the image.

    Each object's box is (left, top, left + width * scaleX, top + height * scaleY),
    using the object's own fields. ``scaleX``/``scaleY`` default to 1 when absent.
    Start edges are floored and end edges ceiled so the box fully covers the
    drawn area. The box is then clipped to ``[0, width] x [0, height]``.

    Args:
        objects: DataFrame of canvas objects (one row per drawn object).
        height: Image height in pixels (upper clip bound for y).
        width: Image width in pixels (upper clip bound for x).

    Returns:
        An ``int`` array of shape ``(k, 4)`` with columns (left, top, right,
        bottom). Boxes with zero area after clipping are dropped.
    """
    def column(name, default):
        # Missing columns / per-row NaNs fall back to ``default``.
        if name in objects:
            return objects[name].fillna(default).to_numpy(dtype=float)
        return np.full(len(objects), default, dtype=float)

    left = column("left", 0.0)
    top = column("top", 0.0)
    right = left + column("width", 0.0) * column("scaleX", 1.0)
    bottom = top + column("height", 0.0) * column("scaleY", 1.0)

    # NumPy slicing needs integer indices; float coordinates raise TypeError.
    boxes = np.stack(
        [np.floor(left), np.floor(top), np.ceil(right), np.ceil(bottom)], axis=1
    ).astype(int)
    # Clip so negative coordinates cannot wrap around as negative slice starts.
    boxes[:, [0, 2]] = np.clip(boxes[:, [0, 2]], 0, width)
    boxes[:, [1, 3]] = np.clip(boxes[:, [1, 3]], 0, height)
    keep = (boxes[:, 2] > boxes[:, 0]) & (boxes[:, 3] > boxes[:, 1])
    return boxes[keep]


def get_mask_from_rectangles(image, mask, height, width, drawing_mode='rect'):
    """Let the user draw over ``image`` and paint each drawn box into ``mask``.

    Renders an ``st_canvas`` with ``image`` as its background. For every drawn
    object, the object's bounding box (from its left/top/width/height fields)
    is cropped from ``image`` and displayed. The same region of ``mask`` is
    set to 255.

    Args:
        image: PIL image used as the canvas background and crop source.
        mask: 2-D ``uint8`` array of shape ``(height, width)``; modified in place.
        height: Canvas height in pixels (expected to match ``image``).
        width: Canvas width in pixels (expected to match ``image``).
        drawing_mode: Drawing mode forwarded to ``st_canvas`` (default ``'rect'``).
            NOTE(review): for non-rectangle modes only the object's bounding
            box is masked.

    Returns:
        ``mask``, with the selected regions set to 255.
    """
    canvas_result = st_canvas(
        fill_color="rgba(255, 165, 0, 0.3)",
        stroke_width=2,
        stroke_color="#000000",
        background_image=image,
        update_streamlit=True,
        height=height,
        width=width,
        drawing_mode=drawing_mode,
        point_display_radius=5,
        key="canvas",
    )

    if canvas_result.json_data is None:
        return mask

    objects = pd.json_normalize(canvas_result.json_data["objects"])
    if objects.empty:
        return mask

    boxes = _canvas_boxes(objects, height, width)
    if len(boxes) == 0:
        return mask

    for left, top, right, bottom in boxes:
        st.image(image.crop((left, top, right, bottom)))
        mask[top:bottom, left:right] = 255

    st.header("Mask Created!")
    st.image(mask)
    return mask


def get_mask(image, edit_method, height, width):
    """Build an inpainting mask for ``image`` using the chosen edit method.

    Args:
        image: PIL image to be edited.
        edit_method: ``"AutoSegment Area"`` to pick model-predicted segments, or
            ``"Draw Custom Area"`` to draw boxes on a canvas. Any other value
            leaves the mask empty.
        height: Mask height in pixels.
        width: Mask width in pixels.

    Returns:
        The mask, in one of three forms:
        - AutoSegment with at least one segment selected: a PIL ``Image`` that
          is 255 on the selected segments and 0 elsewhere.
        - Draw Custom Area: the ``uint8`` array returned by
          ``get_mask_from_rectangles``.
        - Otherwise: an all-zero ``uint8`` array of shape ``(height, width)``.
    """
    mask = np.zeros((height, width), dtype=np.uint8)

    if edit_method == "AutoSegment Area":
        # get displayable segmented image
        seg_prediction, segment_labels = segment_image(image)
        # presumably an integer label map with one id per pixel - confirm in model.py
        seg = seg_prediction['segmentation'].cpu().numpy()

        # Integer inputs index the colormap LUT directly, so labels 0..max need
        # max + 1 entries. With only ``max`` entries the top label falls into
        # the "over" slot and shares a colour with its neighbour.
        n_colors = max(int(np.max(seg)) + 1, 1)
        viridis = mpl.colormaps.get_cmap('viridis').resampled(n_colors)
        seg_image = Image.fromarray(np.uint8(viridis(seg) * 255))
        st.image(seg_image)

        # prompt user to select valid labels to edit; options are (id, label) pairs
        seg_selections = st.multiselect("Choose segments", list(segment_labels.items()))
        if seg_selections:
            target_ids = [segment_id for segment_id, _ in seg_selections]
            mask = Image.fromarray(np.isin(seg, target_ids).astype(np.uint8) * 255)
            st.header("Mask Created!")
            st.image(mask)

    elif edit_method == "Draw Custom Area":
        mask = get_mask_from_rectangles(image, mask, height, width)

    return mask


def main():
    """Streamlit entry point: upload, downscale, mask, then inpaint an image."""
    st.title("Stable Edit")
    st.title("Edit your photos with Stable Diffusion!")
    st.write(f"Device found: {device}")

    raw_sf = st.text_input("Please enter resizing scale factor to downsize image (default=2)", value="2")
    try:
        sf = int(raw_sf)
    except ValueError:
        sf = None
    # A factor below 1 would divide by zero or produce an empty/negative size.
    if sf is None or sf < 1:
        st.write("Error with input scale factor, setting to default value of 2, please re-enter above to change it")
        sf = 2

    # upload image
    uploaded = st.file_uploader("upload an image")
    if not uploaded:
        return

    image = Image.open(uploaded)
    # Snap each dimension down to a multiple of 8 (presumably required by the
    # inpainting model - confirm in model.py). Clamp to at least 8 so a large
    # scale factor cannot yield a zero-sized image.
    width = max(8, closest_number(image.width / sf))
    height = max(8, closest_number(image.height / sf))
    image = image.resize((width, height))
    st.image(image)

    # Select an editing method
    edit_method = st.selectbox("Select Edit Method", ("AutoSegment Area", "Draw Custom Area"))
    if not edit_method:
        return
    mask = get_mask(image, edit_method, height, width)

    # get inpainted images
    prompt = st.text_input("Please enter prompt for image inpainting", value="")
    if not prompt:
        return

    st.write("Inpainting Images, patience is a virtue and this will take a while to run on a CPU :)")
    images = inpaint(image, mask, width, height, prompt=prompt, seed=0, guidance_scale=17.5, num_samples=3)

    # display all images
    st.write("Original Image")
    st.image(image)
    for i, img in enumerate(images, 1):
        st.write(f"result: {i}")
        st.image(img)


if __name__ == '__main__':
    main()