|
import streamlit as st
|
|
from streamlit_drawable_canvas import st_canvas
|
|
from PIL import Image
|
|
import numpy as np
|
|
import matplotlib.pyplot as plt
|
|
import image_mask_gen
|
|
import torch
|
|
from sam2.build_sam import build_sam2
|
|
from sam2.sam2_image_predictor import SAM2ImagePredictor
|
|
import os
|
|
import io
|
|
import warnings
|
|
from stability_sdk import client
|
|
import stability_sdk.interfaces.gooseai.generation.generation_pb2 as generation
|
|
|
|
import streamlit as st
|
|
import base64
|
|
|
|
|
|
|
|
def show_points(coords, labels, ax, marker_size=375):
    """Draw prompt points on ``ax`` as white-edged star markers.

    Rows of ``coords`` whose label is 1 are drawn green; rows whose label is 0
    are drawn red. ``coords`` is indexed as ``coords[mask][:, 0]`` / ``[:, 1]``,
    so it must be a 2-D NumPy array of (x, y) rows, and ``labels`` a NumPy array
    of the same length.

    Args:
        coords: (x, y) point coordinates, one per row.
        labels: Per-point labels (1 = include, 0 = exclude).
        ax: Matplotlib-style axes exposing ``scatter``.
        marker_size: Marker area passed as ``s`` to ``scatter``.
    """
    # Green (label 1) is always drawn before red (label 0).
    for wanted_label, colour in ((1, 'green'), (0, 'red')):
        chosen = coords[labels == wanted_label]
        ax.scatter(
            chosen[:, 0],
            chosen[:, 1],
            color=colour,
            marker='*',
            s=marker_size,
            edgecolor='white',
            linewidth=1.25,
        )
|
|
|
|
def remove_duplicates(coords, labels):
    """Drop repeated coordinates, keeping each one's first occurrence and label.

    Coordinates are compared as ``tuple(coord)``. Input order is preserved.
    ``coords`` and ``labels`` are paired with ``zip``, so extra trailing items in
    the longer sequence are ignored.

    Returns:
        ``(unique_coords, unique_labels)``: two parallel lists holding the
        original coordinate objects and their labels.
    """
    # dicts keep insertion order, and setdefault never overwrites the first hit.
    first_by_key = {}
    for coord, label in zip(coords, labels):
        first_by_key.setdefault(tuple(coord), (coord, label))
    kept = list(first_by_key.values())
    return [coord for coord, _ in kept], [label for _, label in kept]
|
|
|
|
|
|
def image_augmentation_page():
    """Render the point-prompted SAM2 segmentation and mask-augmentation page.

    Flow within one Streamlit script run:
      1. Upload an image. Click inclusive (green) and exclusive (red) points on
         two canvases drawn over a preview at most 700 px wide.
      2. Map the clicks back to original-image pixels and de-duplicate them.
         Run a two-pass SAM2 prediction, refining the best multimask candidate
         into a single mask.
      3. Show the mask and inverse-mask renderings from ``image_mask_gen``.
      4. Optionally apply a sidebar-selected augmentation to ``masks[0]``.

    Returns early until an image is uploaded and at least one point is clicked.
    """
    st.title("Image Augmentation")
    st.write("Upload an image to apply augmentation techniques.")

    # Clicked points, in original-image pixel coordinates, persist across reruns.
    if "inclusive_points" not in st.session_state:
        st.session_state.inclusive_points = []
    if "exclusive_points" not in st.session_state:
        st.session_state.exclusive_points = []
    # This counter is baked into the canvas widget keys so "Clear All Points" can
    # hand out fresh, empty canvases. A canvas keeps its drawn objects for as
    # long as its key lives.
    if "canvas_generation" not in st.session_state:
        st.session_state.canvas_generation = 0

    uploaded_file = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"])
    if uploaded_file is None:
        return

    image = Image.open(uploaded_file)

    # Shrink (never enlarge) the preview so it fits a 700 px wide canvas.
    max_display_width = 700
    scale_factor = min(max_display_width / image.size[0], 1)
    display_width = int(image.size[0] * scale_factor)
    display_height = int(image.size[1] * scale_factor)
    resized_image = image.resize((display_width, display_height))
    generation_id = st.session_state.canvas_generation

    st.subheader("Select Inclusive Points (Green)")
    canvas_inclusive = _point_canvas(
        resized_image, display_width, display_height, f"canvas_inclusive_{generation_id}"
    )
    inclusive_clicks = _canvas_click_points(canvas_inclusive, scale_factor)
    if inclusive_clicks is not None:
        # The canvas reports every object drawn so far on each rerun, so replace
        # the list. Extending it re-appended all previous clicks on every rerun.
        st.session_state.inclusive_points = inclusive_clicks
    _plot_points_over_image(image, st.session_state.inclusive_points, label=1)

    st.divider()

    st.subheader("Select Exclusive Points (Red)")
    canvas_exclusive = _point_canvas(
        resized_image, display_width, display_height, f"canvas_exclusive_{generation_id}"
    )
    exclusive_clicks = _canvas_click_points(canvas_exclusive, scale_factor)
    if exclusive_clicks is not None:
        st.session_state.exclusive_points = exclusive_clicks
    _plot_points_over_image(image, st.session_state.exclusive_points, label=0)

    coordinates = st.session_state.inclusive_points + st.session_state.exclusive_points
    labels = [1] * len(st.session_state.inclusive_points) + [0] * len(st.session_state.exclusive_points)

    if st.button("Clear All Points"):
        st.session_state.inclusive_points = []
        st.session_state.exclusive_points = []
        # New keys give empty canvases. Otherwise the old clicks would be read
        # straight back from the canvases on the next rerun.
        st.session_state.canvas_generation += 1
        st.rerun()

    # When a point appears on both canvases, the inclusive (first) one wins.
    unique_coordinates, unique_labels = remove_duplicates(coordinates, labels)
    st.write("Unique Coordinates:", tuple(unique_coordinates))
    st.write("Unique Labels:", tuple(unique_labels))

    if not unique_coordinates:
        # SAM2 needs at least one prompt point; np.array([]) is not an (N, 2) array.
        st.info("Click at least one point on a canvas to run segmentation.")
        return

    input_point = np.array(unique_coordinates)
    input_label = np.array(unique_labels)
    masks, scores = _predict_refined_mask(image, input_point, input_label)
    image_mask_gen.show_masks(image, masks, scores, point_coords=input_point, input_labels=input_label)

    # Separate handle on the upload for the rendering / augmentation helpers.
    original_image = Image.open(uploaded_file)

    with st.container(border=True):
        col1, col2 = st.columns(2)
        with col1:
            mask_images = image_mask_gen.show_masks_1(original_image, masks, scores)
            for idx, (img, score) in enumerate(mask_images):
                st.image(img, caption=f'Mask {idx+1}, Score: {score:.3f}', use_column_width=True)
        with col2:
            inverse_mask_images = image_mask_gen.show_inverse_masks(original_image, masks, scores)
            for idx, (img, score) in enumerate(inverse_mask_images):
                st.image(img, caption=f'Inverse Mask {idx+1}, Score: {score:.3f}', use_column_width=True)

    if st.checkbox("Proceed to Image Augmentation"):
        _augmentation_controls(original_image, masks, scores)


def _point_canvas(background, width, height, key):
    """Render one click-capture canvas over the downscaled preview image."""
    return st_canvas(
        fill_color="rgba(0, 0, 0, 0)",
        stroke_width=1,
        stroke_color="blue",
        background_image=background,
        update_streamlit=True,
        height=height,
        width=width,
        drawing_mode="circle",
        point_display_radius=3,
        key=key,
    )


def _canvas_click_points(canvas_result, scale_factor):
    """Map canvas objects to original-image (x, y) points.

    Returns ``None`` when the canvas has not reported any data yet.
    """
    if canvas_result.json_data is None:
        return None
    # NOTE(review): "left/top + radius" is the centre only for objects whose
    # origin is their top-left corner; confirm against the drawing mode in use.
    # Dividing by scale_factor undoes the preview downscale.
    return [
        [(obj["left"] + obj["radius"]) / scale_factor, (obj["top"] + obj["radius"]) / scale_factor]
        for obj in canvas_result.json_data["objects"]
    ]


def _plot_points_over_image(image, points, label):
    """Show ``image`` with ``points`` overlaid (all given the same ``label``)."""
    fig, ax = plt.subplots()
    ax.imshow(image)
    ax.axis('off')
    if points:
        show_points(np.array(points), np.array([label] * len(points)), ax)
    st.pyplot(fig)
    # pyplot keeps every open figure alive across reruns unless it is closed.
    plt.close(fig)


@st.cache_resource
def _load_sam2_model(model_cfg="sam2_hiera_b+.yaml", checkpoint="sam2_hiera_base_plus.pt", device="cpu"):
    """Build the SAM2 model once per server process instead of on every rerun.

    Only the model is cached. Callers wrap it in their own
    ``SAM2ImagePredictor``, because the predictor holds the state of the image
    passed to ``set_image``.
    """
    return build_sam2(model_cfg, checkpoint, device=device)


def _predict_refined_mask(image, point_coords, point_labels):
    """Run SAM2 twice on ``image`` and return ``(masks, scores)`` for one mask.

    Pass 1 requests multiple candidate masks. The logits of the highest-scoring
    candidate are then fed back as ``mask_input`` with the same points, using
    ``multimask_output=False``.
    """
    predictor = SAM2ImagePredictor(_load_sam2_model())
    # SAM2 expects 3-channel input; PNG uploads are often RGBA or palette images.
    predictor.set_image(image.convert("RGB"))
    _, scores, logits = predictor.predict(
        point_coords=point_coords,
        point_labels=point_labels,
        multimask_output=True,
    )
    best_logits = logits[np.argmax(scores)]
    masks, scores, _ = predictor.predict(
        point_coords=point_coords,
        point_labels=point_labels,
        mask_input=best_logits[None, :, :],
        multimask_output=False,
    )
    return masks, scores


def _augmentation_controls(original_image, masks, scores):
    """Apply the sidebar-selected augmentation to ``masks[0]`` and show the result."""
    image_aug_select = st.sidebar.selectbox(
        "Select Augmentation for Mask",
        ["Pixelate", "Hue Change", "Mask Replacement", "Generative Img2Img"],
    )

    if image_aug_select == "Pixelate":
        if st.sidebar.toggle("Proceed to Pixelate Mask"):
            pixelation_level = st.slider("Select Pixelation Level", min_value=5, max_value=50, value=10)
            combined_image = image_mask_gen.combine_pixelated_mask(original_image, masks[0], pixelation_level)
            st.image(combined_image, caption="Combined Pixelated Image", use_column_width=True)

    elif image_aug_select == "Hue Change":
        if st.sidebar.toggle("Proceed to Hue Change"):
            hue_shift = st.slider("Select Hue Shift", min_value=-180, max_value=180, value=0)
            combined_image = image_mask_gen.combine_hue_changed_mask(original_image, masks[0], hue_shift)
            st.image(combined_image, caption="Combined Hue Changed Image", use_column_width=True)

    elif image_aug_select == "Mask Replacement":
        if st.sidebar.toggle("Proceed to replace Mask"):
            replacement_file = st.file_uploader("Upload the replacement image", type=["png", "jpg", "jpeg"])
            if replacement_file is not None:
                replacement_image = Image.open(replacement_file)
                combined_image = image_mask_gen.combine_mask_replaced_image(
                    original_image, replacement_image, masks[0]
                )
                st.image(combined_image, caption="Masked Area Replaced Image", use_column_width=True)

    elif image_aug_select == "Generative Img2Img":
        _generative_img2img(original_image, masks, scores)


def _generative_img2img(original_image, masks, scores):
    """Generate a Stability AI img2img result from the mask rendering and blend it back."""
    rendered = list(image_mask_gen.show_masks_1(original_image, masks, scores))
    if not rendered:
        st.warning("No mask image is available for generation.")
        return
    # Use the last rendered mask image as the init image (as before).
    masked_image, _ = rendered[-1]
    init_image = image_mask_gen.resize_image(masked_image.convert("RGB"))
    width, height = init_image.size

    prompt = st.text_input(
        "Enter your prompt:",
        "A Beautiful day, in the style reference of starry night by vincent van gogh",
    )
    api_key = st.text_input("Enter your Stability AI API key:", type="password")
    if not (prompt and api_key):
        return

    # Pass the key straight to the client. Writing it into os.environ would expose
    # one user's key to every other session served by this Streamlit process.
    stability_api = client.StabilityInference(
        key=api_key,
        verbose=True,
        engine="stable-diffusion-xl-1024-v1-0",
    )
    style_preset_selector = st.sidebar.selectbox(
        "Select Style Preset",
        ["3d-model", "analog-film", "anime", "cinematic", "comic-book", "digital-art", "enhance",
         "fantasy-art", "isometric", "line-art", "low-poly", "modeling-compound", "neon-punk",
         "origami", "photographic", "pixel-art", "tile-texture"],
        index=5,
    )
    if not st.sidebar.toggle("Proceed to Generate Image"):
        return

    answers = stability_api.generate(
        prompt=prompt,
        init_image=init_image,
        start_schedule=0.6,
        steps=250,
        cfg_scale=10.0,
        width=width,
        height=height,
        sampler=generation.SAMPLER_K_DPMPP_SDE,
        style_preset=style_preset_selector,
    )

    for resp in answers:
        for artifact in resp.artifacts:
            if artifact.finish_reason == generation.FILTER:
                # Surface this in the UI; warnings.warn only reaches the server log.
                st.warning(
                    "Your request activated the API's safety filters and could not be processed. "
                    "Please modify the prompt and try again."
                )
            if artifact.type == generation.ARTIFACT_IMAGE:
                generated = Image.open(io.BytesIO(artifact.binary))
                st.image(generated, caption="Generated Image", use_column_width=True)
                # Combine only once an image exists; otherwise the name is unbound.
                combined_img = image_mask_gen.combine_mask_and_inverse_gen(original_image, generated, masks[0])
                st.image(combined_img, caption="Combined Image", use_column_width=True)