import base64
import io
import os
import warnings

import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
import torch
from PIL import Image
from streamlit_drawable_canvas import st_canvas

import image_mask_gen
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
from stability_sdk import client
import stability_sdk.interfaces.gooseai.generation.generation_pb2 as generation


# Function to display points on the image using matplotlib:
# green stars mark inclusive points (label 1), red stars mark exclusive points (label 0).
def show_points(coords, labels, ax, marker_size=375):
    pos_points = coords[labels == 1]
    neg_points = coords[labels == 0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*',
               s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*',
               s=marker_size, edgecolor='white', linewidth=1.25)


# The canvas re-reports every object on each Streamlit rerun, so clicks
# accumulate duplicates; keep only the first occurrence of each coordinate.
def remove_duplicates(coords, labels):
    unique_coords = []
    unique_labels = []
    seen = set()
    for coord, label in zip(coords, labels):
        coord_tuple = tuple(coord)
        if coord_tuple not in seen:
            seen.add(coord_tuple)
            unique_coords.append(coord)
            unique_labels.append(label)
    return unique_coords, unique_labels


def image_augmentation_page():
    st.title("Image Augmentation")
    st.write("Upload an image to apply augmentation techniques.")

    # Initialize session state variables
    if "inclusive_points" not in st.session_state:
        st.session_state.inclusive_points = []
    if "exclusive_points" not in st.session_state:
        st.session_state.exclusive_points = []

    # Upload an image
    uploaded_file = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"])

    if uploaded_file is not None:
        # Open the uploaded image
        image = Image.open(uploaded_file)

        # Scale the image down for display so large uploads fit the page;
        # clicks are mapped back to original-image coordinates below.
        max_display_width = 700  # Adjust this value as needed.
        scale_factor = min(max_display_width / image.size[0], 1)
        display_width = int(image.size[0] * scale_factor)
        display_height = int(image.size[1] * scale_factor)
        resized_image = image.resize((display_width, display_height))

        # Inclusive points phase: clicks on this canvas become positive prompts.
        st.subheader("Select Inclusive Points (Green)")
        canvas_inclusive = st_canvas(
            fill_color="rgba(0, 0, 0, 0)",  # Transparent fill
            stroke_width=1,                 # Stroke width for drawing
            stroke_color="blue",            # Outline color of a click
            background_image=resized_image,
            update_streamlit=True,
            height=display_height,
            width=display_width,
            drawing_mode="circle",          # Capture clicks as circles
            point_display_radius=3,         # Radius of the circle drawn per click
            key="canvas_inclusive",
        )

        # Convert canvas clicks (circle centers) back to original-image coordinates.
        if canvas_inclusive.json_data is not None:
            objects = canvas_inclusive.json_data["objects"]
            new_clicks = [
                [(obj["left"] + obj["radius"]) / scale_factor,
                 (obj["top"] + obj["radius"]) / scale_factor]
                for obj in objects
            ]
            st.session_state.inclusive_points.extend(new_clicks)

        # Plot the inclusive points on the original image using Matplotlib.
        fig_inclusive, ax = plt.subplots()
        ax.imshow(image)
        ax.axis('off')  # Hide the axes

        inclusive_points = np.array(st.session_state.inclusive_points)
        labels_inclusive = np.array([1] * len(st.session_state.inclusive_points))
        if len(inclusive_points) > 0:
            show_points(inclusive_points, labels_inclusive, ax)
        st.pyplot(fig_inclusive)

        st.divider()

        # Exclusive points phase: clicks on this canvas become negative prompts.
        st.subheader("Select Exclusive Points (Red)")
        canvas_exclusive = st_canvas(
            fill_color="rgba(0, 0, 0, 0)",
            stroke_width=1,
            stroke_color="blue",
            background_image=resized_image,
            update_streamlit=True,
            height=display_height,
            width=display_width,
            drawing_mode="circle",
            point_display_radius=3,
            key="canvas_exclusive",
        )

        if canvas_exclusive.json_data is not None:
            objects = canvas_exclusive.json_data["objects"]
            new_clicks = [
                [(obj["left"] + obj["radius"]) / scale_factor,
                 (obj["top"] + obj["radius"]) / scale_factor]
                for obj in objects
            ]
            st.session_state.exclusive_points.extend(new_clicks)

        # Plot the exclusive points on the original image using Matplotlib.
        fig_exclusive, ax = plt.subplots()
        ax.imshow(image)
        ax.axis('off')

        exclusive_points = np.array(st.session_state.exclusive_points)
        labels_exclusive = np.array([0] * len(st.session_state.exclusive_points))
        if len(exclusive_points) > 0:
            show_points(exclusive_points, labels_exclusive, ax)
        st.pyplot(fig_exclusive)

        # Group coordinates and labels: 1 = inclusive, 0 = exclusive.
        coordinates = st.session_state.inclusive_points + st.session_state.exclusive_points
        labels = [1] * len(st.session_state.inclusive_points) + [0] * len(st.session_state.exclusive_points)

        # Provide an option to clear the collected points.
        if st.button("Clear All Points"):
            st.session_state.inclusive_points = []
            st.session_state.exclusive_points = []

        unique_coordinates, unique_labels = remove_duplicates(coordinates, labels)
        st.write("Unique Coordinates:", tuple(unique_coordinates))
        st.write("Unique Labels:", tuple(unique_labels))

        # Load SAM2 and prompt it with the collected points.
        sam2_checkpoint = "sam2_hiera_base_plus.pt"
        model_cfg = "sam2_hiera_b+.yaml"
        sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cpu")
        predictor = SAM2ImagePredictor(sam2_model)

        # The predictor accepts a PIL image or an RGB numpy array.
        predictor.set_image(image)

        input_point = np.array(unique_coordinates)
        input_label = np.array(unique_labels)

        # First pass: request multiple candidate masks and sort them by score.
        masks, scores, logits = predictor.predict(
            point_coords=input_point,
            point_labels=input_label,
            multimask_output=True,
        )
        sorted_ind = np.argsort(scores)[::-1]
        masks = masks[sorted_ind]
        scores = scores[sorted_ind]
        logits = logits[sorted_ind]

        # Second pass: feed the best low-resolution mask back in as a prompt
        # to refine a single output mask.
        mask_input = logits[np.argmax(scores), :, :]
        masks, scores, _ = predictor.predict(
            point_coords=input_point,
            point_labels=input_label,
            mask_input=mask_input[None, :, :],
            multimask_output=False,
        )

        image_mask_gen.show_masks(image, masks, scores, point_coords=input_point, input_labels=input_label)

        # Re-open the upload so the full-resolution image is used for masking.
        original_image = Image.open(uploaded_file)

        # Display the masked and inverse-masked images side by side.
        with st.container(border=True):
            col1, col2 = st.columns(2)
            with col1:
                mask_images = image_mask_gen.show_masks_1(original_image, masks, scores)
                for idx, (img, score) in enumerate(mask_images):
                    st.image(img, caption=f'Mask {idx+1}, Score: {score:.3f}', use_column_width=True)
            with col2:
                inverse_mask_images = image_mask_gen.show_inverse_masks(original_image, masks, scores)
                for idx, (img, score) in enumerate(inverse_mask_images):
                    st.image(img, caption=f'Inverse Mask {idx+1}, Score: {score:.3f}', use_column_width=True)

        if st.checkbox("Proceed to Image Augmentation"):
            image_aug_select = st.sidebar.selectbox(
                "Select Augmentation for Mask",
                ["Pixelate", "Hue Change", "Mask Replacement", "Generative Img2Img"],
            )

            if image_aug_select == "Pixelate":
                if st.sidebar.toggle("Proceed to Pixelate Mask"):
                    pixelation_level = st.slider("Select Pixelation Level", min_value=5, max_value=50, value=10)
                    combined_image = image_mask_gen.combine_pixelated_mask(original_image, masks[0], pixelation_level)
                    st.image(combined_image, caption="Combined Pixelated Image", use_column_width=True)

            elif image_aug_select == "Hue Change":
                if st.sidebar.toggle("Proceed to Hue Change"):
                    # Hue shift slider
                    hue_shift = st.slider("Select Hue Shift", min_value=-180, max_value=180, value=0)
                    # Apply the hue change to the masked area (single mask assumed).
                    combined_image = image_mask_gen.combine_hue_changed_mask(original_image, masks[0], hue_shift)
                    st.image(combined_image, caption="Combined Hue Changed Image", use_column_width=True)

            elif image_aug_select == "Mask Replacement":
                if st.sidebar.toggle("Proceed to replace Mask"):
                    replacement_file = st.file_uploader("Upload the replacement image", type=["png", "jpg", "jpeg"])
                    if replacement_file is not None:
                        replacement_image = Image.open(replacement_file)
                        combined_image = image_mask_gen.combine_mask_replaced_image(original_image, replacement_image, masks[0])
                        st.image(combined_image, caption="Masked Area Replaced Image", use_column_width=True)

            elif image_aug_select == "Generative Img2Img":
                # Use the last masked crop as the init image for img2img.
                msk_img = None
                mask_images_x = image_mask_gen.show_masks_1(original_image, masks, scores)
                for idx, (img, score) in enumerate(mask_images_x):
                    msk_img = img

                rgb_image = msk_img.convert("RGB")
                # Resize to dimensions the Stability engine accepts.
                resized_image = image_mask_gen.resize_image(rgb_image)
                width, height = resized_image.size

                # User input for the prompt and API key
                prompt = st.text_input("Enter your prompt:", "A Beautiful day, in the style reference of starry night by vincent van gogh")
                api_key = st.text_input("Enter your Stability AI API key:")

                if prompt and api_key:
                    # Set up the connection to the Stability API.
                    os.environ['STABILITY_KEY'] = api_key
                    stability_api = client.StabilityInference(
                        key=os.environ['STABILITY_KEY'],         # API key reference
                        verbose=True,                            # Print debug messages
                        engine="stable-diffusion-xl-1024-v1-0",  # Engine to use for generation
                    )

                    style_preset_selector = st.sidebar.selectbox(
                        "Select Style Preset",
                        ["3d-model", "analog-film", "anime", "cinematic", "comic-book",
                         "digital-art", "enhance", "fantasy-art", "isometric", "line-art",
                         "low-poly", "modeling-compound", "neon-punk", "origami",
                         "photographic", "pixel-art", "tile-texture"],
                        index=5,
                    )

                    if st.sidebar.toggle("Proceed to Generate Image"):
                        # Set up the initial generation parameters.
                        answers2 = stability_api.generate(
                            prompt=prompt,
                            init_image=resized_image,  # The masked crop is the initial image.
                            start_schedule=0.6,        # How strongly to transform the init image.
                            steps=250,
                            cfg_scale=10.0,
                            width=width,
                            height=height,
                            sampler=generation.SAMPLER_K_DPMPP_SDE,
                            style_preset=style_preset_selector,
                        )

                        # Process the response from the API.
                        for resp in answers2:
                            for artifact in resp.artifacts:
                                if artifact.finish_reason == generation.FILTER:
                                    warnings.warn(
                                        "Your request activated the API's safety filters and could not be processed. "
                                        "Please modify the prompt and try again.")
                                if artifact.type == generation.ARTIFACT_IMAGE:
                                    img2 = Image.open(io.BytesIO(artifact.binary))
                                    # Display the generated image
                                    st.image(img2, caption="Generated Image", use_column_width=True)
                                    # Blend the generated image back into the original via the mask.
                                    combined_img = image_mask_gen.combine_mask_and_inverse_gen(original_image, img2, masks[0])
                                    st.image(combined_img, caption="Combined Image", use_column_width=True)