import math
import os

import cv2
import gradio as gr
import matplotlib.pyplot as plt
import mediapipe as mp
import numpy as np
import torch
from diffusers import StableDiffusionInpaintPipeline
from mediapipe.tasks import python
from mediapipe.tasks.python import vision
from mediapipe.tasks.python.components import containers
from PIL import Image, ImageDraw
from skimage.measure import label, regionprops
from transformers import pipeline
def _normalized_to_pixel_coordinates(
        normalized_x: float, normalized_y: float, image_width: int, image_height: int):
    """Converts a normalized (x, y) pair to pixel coordinates."""

    # Checks whether the float value lies in [0, 1].
    def is_valid_normalized_value(value: float) -> bool:
        return (value > 0 or math.isclose(0, value)) and (value < 1 or math.isclose(1, value))

    if not (is_valid_normalized_value(normalized_x) and is_valid_normalized_value(normalized_y)):
        # TODO: Draw coordinates even if it's outside of the image bounds.
        return None
    x_px = min(math.floor(normalized_x * image_width), image_width - 1)
    y_px = min(math.floor(normalized_y * image_height), image_height - 1)
    return x_px, y_px
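# Minimal usage sketch (hypothetical values): on a 640x480 image, a click at
# normalized (0.5, 0.25) maps to pixel coordinates (320, 120).
# _normalized_to_pixel_coordinates(0.5, 0.25, 640, 480)  # -> (320, 120)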
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# float16 weights only make sense on GPU; fall back to float32 when running on CPU.
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-inpainting",
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
).to(device)
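# Optional memory savings on smaller GPUs (a sketch; both helpers exist in recent
# diffusers releases, and enable_model_cpu_offload also needs the accelerate package):
# pipe.enable_attention_slicing()
# pipe.enable_model_cpu_offload()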
#from huggingface_hub import login
#login()
#pipe2 = StableDiffusion3Pipeline.from_pretrained("stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16)
#pipe2.to("cuda")
BG_COLOR = (192, 192, 192)  # gray
MASK_COLOR = (255, 255, 255)  # white

RegionOfInterest = vision.InteractiveSegmenterRegionOfInterest
NormalizedKeypoint = containers.keypoint.NormalizedKeypoint

# Create the options that will be used for the MediaPipe InteractiveSegmenter
base_options = python.BaseOptions(model_asset_path='model.tflite')
options = vision.ImageSegmenterOptions(base_options=base_options, output_category_mask=True)
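# 'model.tflite' is assumed to be the MediaPipe interactive segmenter model placed next to
# this script. If it is missing, a download sketch (the URL is the one used in MediaPipe's
# sample notebooks at the time of writing and may need to be verified):
# import urllib.request
# urllib.request.urlretrieve(
#     "https://storage.googleapis.com/mediapipe-models/interactive_segmenter/magic_touch/float32/1/magic_touch.tflite",
#     "model.tflite")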
def create_bounding_box_mask(image):
    # In the segmentation mask the selected object is marked with 0, so treat zero-valued
    # pixels as the region of interest (avoids the uint8 wrap-around of `1 - image` on 0/255 masks).
    y_indices, x_indices = np.nonzero(image == 0)
    if not y_indices.size or not x_indices.size:
        return None  # No region found; callers may treat this as an empty mask or raise an error.

    # Compute the bounding box coordinates
    x_min, x_max = x_indices.min(), x_indices.max()
    y_min, y_max = y_indices.min(), y_indices.max()

    # Create a new single-channel mask and fill the bounding box with 1
    bounding_mask = np.zeros_like(image, dtype=np.uint8)
    bounding_mask[y_min:y_max+1, x_min:x_max+1] = 1
    return bounding_mask
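# A toy check of the helper (hypothetical values): with zeros marking the object,
# the returned mask is 1 over the object's bounding box and 0 elsewhere.
# toy = np.full((4, 4), 255, dtype=np.uint8); toy[1:3, 1:3] = 0
# create_bounding_box_mask(toy)  # -> 1s over rows 1-2, cols 1-2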
def segment_2(image_np, coordinates):
    OVERLAY_COLOR = (255, 105, 180)  # pink

    # Create the segmenter
    with python.vision.InteractiveSegmenter.create_from_options(options) as segmenter:
        image = mp.Image(image_format=mp.ImageFormat.SRGB, data=image_np)

        # Parse the "(x, y)" string: strip the parentheses, split on the comma,
        # and convert both values to floats
        coordinates = coordinates.strip("()")
        valeurs = coordinates.split(',')
        x = float(valeurs[0])
        y = float(valeurs[1])

        # Retrieve the category mask for the clicked region of interest
        roi = RegionOfInterest(format=RegionOfInterest.Format.KEYPOINT,
                               keypoint=NormalizedKeypoint(x, y))
        segmentation_result = segmenter.segment(image, roi)
        category_mask = segmentation_result.category_mask
        mask = (category_mask.numpy_view().astype(np.uint8) * 255)

        # Convert the BGR image to RGB
        image_data = cv2.cvtColor(image.numpy_view(), cv2.COLOR_BGR2RGB)

        # Create an overlay image filled with the desired color
        overlay_image = np.zeros(image_data.shape, dtype=np.uint8)
        overlay_image[:] = OVERLAY_COLOR

        # Build the blending condition from the category mask and use it as an
        # alpha channel with 50% opacity over the segmented region
        alpha = np.stack((category_mask.numpy_view(),) * 3, axis=-1) <= 0.1
        alpha = alpha.astype(float) * 0.5

        # Blend the original image and the overlay according to the alpha channel
        output_image = image_data * (1 - alpha) + overlay_image * alpha
        output_image = output_image.astype(np.uint8)

        # Draw a white dot with a black border to mark the point of interest
        thickness, radius = 6, -1
        keypoint_px = _normalized_to_pixel_coordinates(x, y, image.width, image.height)
        if keypoint_px is not None:
            cv2.circle(output_image, keypoint_px, thickness + 5, (0, 0, 0), radius)
            cv2.circle(output_image, keypoint_px, thickness, (255, 255, 255), radius)

        # Build the bounding-box mask of the segmented region as a PIL image
        image_height, image_width = output_image.shape[:2]
        bounding_mask = create_bounding_box_mask(mask)
        bbox_mask_image = Image.fromarray((bounding_mask * 255).astype(np.uint8)).convert("RGB")
        bbox_mask_image = bbox_mask_image.resize((image_width, image_height))

        return output_image, bbox_mask_image
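# A usage sketch outside Gradio (hypothetical file name; requires model.tflite to be present):
# img = np.array(Image.open("photo.jpg").convert("RGB"))
# overlay, bbox_mask = segment_2(img, "(0.5, 0.5)")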
def generate_2(image_np, bbox_image, prompt):
    # Read the input image (a numpy array from Gradio) as a PIL RGB image
    img = Image.fromarray(image_np).convert("RGB")

    # Generate images from the image, the bounding-box mask and the prompt
    images = pipe(prompt=prompt,
                  image=img,
                  mask_image=bbox_image,
                  generator=torch.Generator(device=device).manual_seed(0),
                  num_images_per_prompt=3).images

    # Create an image grid
    def image_grid(imgs, rows, cols):
        assert len(imgs) == rows * cols
        w, h = imgs[0].size
        grid = Image.new('RGB', size=(cols * w, rows * h))
        for i, img in enumerate(imgs):
            grid.paste(img, box=(i % cols * w, i // cols * h))
        return grid

    grid_image = image_grid(images, 1, 3)
    return grid_image
def onclick(evt: gr.SelectData, image):
    if evt:
        x, y = evt.index
        # Normalize the pixel coordinates to [0, 1] and format them as the
        # "(x, y)" string that segment_2 expects
        normalized_x = round(x / image.shape[1], 2)
        normalized_y = round(y / image.shape[0], 2)
        return f"({normalized_x}, {normalized_y})"
    else:
        return None
# Wires together the segment_2 and generate_2 functions defined above
def callback(image, coordinates, prompt):
    # Segment the image around the clicked coordinates
    segmented_image, bbox_image = segment_2(image, coordinates)
    # Inpaint the bounding-box region according to the prompt
    grid_image = generate_2(image, bbox_image, prompt)
    # Return both results for display
    return segmented_image, grid_image
with gr.Blocks() as demo:
    with gr.Row():
        image_input = gr.Image(type="numpy", label="Upload Image", interactive=True)
        coordinates_output = gr.Textbox(label="Coordinates")
    with gr.Row():
        prompt_input = gr.Textbox(label="What do you want to change?")
        submit_button = gr.Button("Submit")
    with gr.Row():
        segmented_image_output = gr.Image(type="numpy", label="Segmented Image")
        grid_image_output = gr.Image(type="pil", label="Generated Image Grid")

    image_input.select(onclick, inputs=[image_input], outputs=coordinates_output)
    submit_button.click(fn=callback,
                        inputs=[image_input, coordinates_output, prompt_input],
                        outputs=[segmented_image_output, grid_image_output])

demo.launch(debug=True)