Spaces:
Sleeping
Sleeping
import gradio as gr | |
from PIL import Image, ImageDraw | |
import matplotlib.pyplot as plt | |
import torch | |
from torchvision import transforms | |
from transformers import AutoModelForImageSegmentation | |
from openai import OpenAI | |
import os | |
import base64 | |
import io | |
import requests | |
import numpy as np | |
from scipy import ndimage | |
from insightface.app import FaceAnalysis | |
# --- Ideogram (background editing) configuration ---------------------------
IDEOGRAM_API_KEY = os.environ.get('IDEOGRAM_API_KEY')
IDEOGRAM_URL = "https://api.ideogram.ai/edit"

# --- Face detector ----------------------------------------------------------
# Only the detection sub-model is enabled; ctx_id/det_size are passed straight
# to FaceAnalysis.prepare (NOTE(review): ctx_id=0 presumably selects GPU 0 when
# available — confirm against the installed insightface version).
face_detection_app = FaceAnalysis(allowed_modules=['detection'])
face_detection_app.prepare(ctx_id=0, det_size=(640, 640))

# --- OpenAI client used to turn the photo into an edit prompt ---------------
client = OpenAI(api_key=os.environ.get('OPENAI_API_KEY'))
GPT_MODEL_NAME = "gpt-4o"
GPT_MAX_TOKENS = 500

# --- Background-removal (segmentation) model --------------------------------
model = AutoModelForImageSegmentation.from_pretrained('briaai/RMBG-2.0', trust_remote_code=True)
torch.set_float32_matmul_precision('high')
if torch.cuda.is_available():
    model = model.to('cuda')
model.eval()  # inference only
# Instruction sent to the vision model together with the photo. The literal
# "{{holiday}}" placeholder is substituted by image_to_prompt() via str.replace.
# The model's reply is forwarded verbatim to the image-edit API, so it is asked
# to answer with the prompt alone (no preamble such as "Here is a prompt:").
GPT_PROMPT = '''
You are a background editor.
Your job is to adjust the background of the image to have a {{holiday}} vibe, while taking into consideration the perspective and the logic of the image.
Your output should be a prompt that can be used to edit the background of the image.
The background should be edited in a way that is consistent with the image.
The prompt should not include any text or writing in the background.
Respond with the prompt only, without any preamble or explanation.
'''
def image_to_prompt(image: Image.Image, holiday: str) -> str:
    """Ask the GPT vision model for a background-editing prompt for *image*.

    Args:
        image: Photo to describe (a PIL image; it is serialized by encode_image).
        holiday: Holiday name substituted for ``{{holiday}}`` in GPT_PROMPT.

    Returns:
        The model's text reply with surrounding whitespace removed.

    Raises:
        RuntimeError: If the model returns no text content (the message
            content field is optional, e.g. on refusals).
    """
    base64_image = encode_image(image)
    messages = [{
        "role": "user",
        "content": [
            {"type": "text", "text": GPT_PROMPT.replace("{{holiday}}", holiday)},
            # encode_image() writes PNG bytes, so the data URL must say image/png.
            {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{base64_image}"}},
        ],
    }]
    response = client.chat.completions.create(
        model=GPT_MODEL_NAME,
        messages=messages,
        max_tokens=GPT_MAX_TOKENS,
    )
    # Guard against a missing reply: otherwise the string "None" would be sent
    # downstream as the background-edit prompt.
    content = (response.choices[0].message.content or "").strip()
    if not content:
        raise RuntimeError("GPT returned an empty background prompt")
    return content
def encode_image(image: Image.Image) -> str:
    """Serialize *image* as PNG and return it as a base64 text string.

    Args:
        image (PIL.Image.Image): The image to serialize.

    Returns:
        str: ASCII base64 encoding of the PNG bytes.
    """
    with io.BytesIO() as png_buffer:
        image.save(png_buffer, format='PNG')
        raw_png = png_buffer.getvalue()
    return base64.b64encode(raw_png).decode('utf-8')
def remove_background(input_image):
    """Segment the foreground of *input_image* with the RMBG model.

    Args:
        input_image (PIL.Image.Image): Photo to segment. Any PIL mode is
            accepted; it is converted to RGB before inference.

    Returns:
        tuple: ``(result_image, only_background_image, mask)`` where
        ``result_image`` is the RGB image with the predicted mask as its alpha
        channel, ``only_background_image`` uses the inverted mask as alpha, and
        ``mask`` is the prediction converted by ToPILImage and resized to the
        input size.
    """
    image_size = (1024, 1024)
    # Normalize below carries exactly three channel statistics, so the model
    # input must be RGB; RGBA / L / P images would fail inside Normalize.
    rgb_image = input_image.convert('RGB')

    transform_image = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    input_tensor = transform_image(rgb_image).unsqueeze(0)
    if torch.cuda.is_available():
        input_tensor = input_tensor.to('cuda')

    with torch.no_grad():
        # The last model output is taken as the mask logits.
        # NOTE(review): assumed shape (batch, 1, H, W) — confirm for RMBG-2.0.
        preds = model(input_tensor)[-1].sigmoid().cpu()
    pred = preds[0].squeeze()
    mask = transforms.ToPILImage()(pred).resize(rgb_image.size)

    # Foreground kept, background transparent.
    result_image = rgb_image.copy()
    result_image.putalpha(mask)

    # Background kept, foreground transparent.
    only_background_image = rgb_image.copy()
    inverted_mask = Image.eval(mask, lambda value: 255 - value)
    only_background_image.putalpha(inverted_mask)

    return result_image, only_background_image, mask
def modify_background(image: Image.Image, mask: Image.Image, prompt: str, *, timeout: float = 120) -> Image.Image:
    """Send *image* and *mask* to Ideogram's edit endpoint and return the edited image.

    Args:
        image: Image to edit.
        mask: Mask sent alongside the image (same size as *image*).
            NOTE(review): confirm Ideogram's mask polarity (which color marks the
            region to repaint) against the API docs.
        prompt: Edit instruction; a "no text" instruction is appended.
        timeout: Seconds to wait for each HTTP request (keyword-only).

    Returns:
        The first generated image, opened as a PIL image.

    Raises:
        RuntimeError: If the edit call fails or its response has no image URL.
        requests.RequestException: On network errors/timeouts, or if downloading
            the generated image returns an HTTP error status.
    """
    def _to_png_bytes(img: Image.Image) -> bytes:
        # Serialize a PIL image to in-memory PNG bytes for multipart upload.
        buffer = io.BytesIO()
        img.save(buffer, format='PNG')
        return buffer.getvalue()

    files = {
        "image_file": ("image.png", _to_png_bytes(image), "image/png"),
        "mask": ("mask.png", _to_png_bytes(mask), "image/png"),
    }
    prevent_text_in_background = "Do not include any text or writing in the background."
    payload = {
        "prompt": f"{prompt} {prevent_text_in_background}",
        "model": "V_2",
        "magic_prompt_option": "ON",
        "num_images": 1,
        "style_type": "REALISTIC",
    }
    headers = {"Api-Key": IDEOGRAM_API_KEY}
    # Without a timeout, requests can block forever on a stalled connection.
    response = requests.post(IDEOGRAM_URL, data=payload, files=files, headers=headers, timeout=timeout)

    if response.status_code == 200:
        # Tolerate a missing or empty "data" list instead of crashing with
        # TypeError/IndexError; fall through to the explicit error below.
        results = response.json().get('data') or []
        result_image_url = results[0].get('url') if results else None
        if result_image_url:
            result_response = requests.get(result_image_url, timeout=timeout)
            result_response.raise_for_status()
            return Image.open(io.BytesIO(result_response.content))

    raise RuntimeError(f"Failed to modify background: {response.text}")
def dilate_mask(mask: Image.Image, size: int = 20) -> Image.Image:
    """Grow the bright regions of *mask* with a square maximum filter.

    In run_flow this is applied to the remove_background mask before the
    Ideogram edit call.

    Args:
        mask: Mask image (single-channel for the masks produced in this file).
        size: Side length, in pixels, of the square filter window.

    Returns:
        The dilated mask as an 8-bit PIL image.
    """
    mask_array = np.asarray(mask)
    # Filter over the two spatial axes only: a scalar ``size`` would also span
    # the channel axis of a multi-channel mask and mix its channels.
    window = (size, size) + (1,) * (mask_array.ndim - 2)
    dilated_mask = ndimage.maximum_filter(mask_array, size=window)
    return Image.fromarray(dilated_mask.astype(np.uint8))
def detect_faces(image: Image.Image) -> list[dict]:
    """Run the InsightFace detector on *image*.

    Args:
        image: PIL image in any mode; it is converted to 3-channel RGB first so
            that RGBA, grayscale or palette uploads do not break the detector.

    Returns:
        The list of faces returned by ``face_detection_app.get``.
    """
    rgb_array = np.asarray(image.convert('RGB'))
    # InsightFace follows the OpenCV convention (images from cv2.imread), i.e.
    # BGR channel order; reverse the channels and make the view contiguous.
    bgr_array = np.ascontiguousarray(rgb_array[:, :, ::-1])
    return face_detection_app.get(bgr_array)
def check_text_position(x, y, text_rect_width, text_rect_height, face_rects, image_width, image_height):
    """Tell whether a text box centered at (x, y) fits the image and avoids all faces.

    The box spans ``x ± text_rect_width // 2`` horizontally and
    ``y ± text_rect_height // 2`` vertically.

    Args:
        x, y: Center of the candidate text box, in pixels.
        text_rect_width, text_rect_height: Size of the text box, in pixels.
        face_rects: Iterable of (x1, y1, x2, y2) face rectangles.
        image_width, image_height: Image dimensions, in pixels.

    Returns:
        True if the box lies inside the image and is disjoint from every face
        rectangle (boxes sharing an edge count as overlapping), else False.
    """
    half_width = text_rect_width // 2
    half_height = text_rect_height // 2
    left, right = x - half_width, x + half_width
    top, bottom = y - half_height, y + half_height

    outside_image = left < 0 or right > image_width or top < 0 or bottom > image_height
    if outside_image:
        return False

    # Two axis-aligned boxes are disjoint iff one lies strictly to one side of the other.
    return all(
        right < face_left or left > face_right or bottom < face_top or top > face_bottom
        for face_left, face_top, face_right, face_bottom in face_rects
    )
def find_place_to_add_text(image: Image.Image, faces: list[dict]) -> tuple[float, float, float, float]:
    """Find a rectangle for overlay text that stays clear of every face.

    Four anchors are tried in order (bottom-center, top-center, left-middle,
    right-middle). The box starts at 80% x 30% of the image and both dimensions
    shrink by 10% per round until a box fits inside the image without touching
    any face box (each padded by 20 px).

    Args:
        image: Image whose ``size`` defines the coordinate space of *faces*.
        faces: Faces detected in *image* coordinates, each indexable as
            ``face['bbox']`` -> (x1, y1, x2, y2). ``None`` is treated as no faces.

    Returns:
        ``(left, top, width, height)`` as fractions of the image width/height.
        If nothing fits even at the minimum size (10% x 3%), a minimum-size box
        centered horizontally with its vertical center at 85% of the image
        height is returned; it may overlap faces.
    """
    image_width, image_height = image.size

    # Padded face rectangles, clipped to the image, for collision tests.
    face_padding = 20
    face_rects = []
    for face in faces or ():
        # Item access works for plain dicts as well as InsightFace Face objects,
        # matching the annotation and the crop helper.
        x1, y1, x2, y2 = map(int, face['bbox'])
        face_rects.append((
            max(0, x1 - face_padding),
            max(0, y1 - face_padding),
            min(image_width, x2 + face_padding),
            min(image_height, y2 + face_padding),
        ))

    # Candidate centers, pulled 10% in from the 15%/85% lines.
    padding_x = int(0.1 * image_width)
    padding_y = int(0.1 * image_height)
    positions = [
        (image_width // 2, int(0.85 * image_height) - padding_y),  # bottom center
        (image_width // 2, int(0.15 * image_height) + padding_y),  # top center
        (int(0.15 * image_width) + padding_x, image_height // 2),  # left middle
        (int(0.85 * image_width) - padding_x, image_height // 2),  # right middle
    ]

    # Box size as a fraction of the image; shrink until something fits.
    current_text_width = 0.8
    current_text_height = 0.3
    min_text_width = 0.1
    min_text_height = 0.03
    reduction_factor = 0.9
    while current_text_width >= min_text_width and current_text_height >= min_text_height:
        text_rect_width = current_text_width * image_width
        text_rect_height = current_text_height * image_height
        for x, y in positions:
            if check_text_position(x, y, text_rect_width, text_rect_height,
                                   face_rects, image_width, image_height):
                top_left_x_in_percent = (x - text_rect_width // 2) / image_width
                top_left_y_in_percent = (y - text_rect_height // 2) / image_height
                return top_left_x_in_percent, top_left_y_in_percent, current_text_width, current_text_height
        current_text_width *= reduction_factor
        current_text_height *= reduction_factor

    print("Failed to find a suitable position")
    # Fallback: minimum-size box, horizontally centered, centered at 85% height.
    return (
        (image_width // 2 - (min_text_width * image_width) // 2) / image_width,
        (int(0.85 * image_height) - (min_text_height * image_height) // 2) / image_height,
        min_text_width,
        min_text_height,
    )
def crop_to_ratio_while_preventing_faces(image: "Image.Image", faces: list[dict]) -> "Image.Image":
    """Crop *image* to 5:7 (portrait) or 7:5 (landscape) while keeping faces in frame.

    Images narrower than they are tall get 5:7; all others get 7:5. Without
    faces a centered crop is taken. With faces, the crop window is centered on
    the union of all face boxes (padded by 50 px) and then clamped to the image.
    As a result, faces are cut only when the group is larger than the crop window.
    Previously the window was pinned to the faces' top-left edge, which left
    almost no headroom above them.

    Args:
        image: Image to crop (only ``size`` and ``crop`` are used).
        faces: Detected faces, each indexable as ``face['bbox']`` ->
            (x1, y1, x2, y2) in *image* coordinates. Empty/None means no faces.

    Returns:
        The result of ``image.crop`` for the chosen window.
    """
    ASPECT_RATIO_PORTRAIT = 5 / 7
    ASPECT_RATIO_LANDSCAPE = 7 / 5
    padding = 50

    image_width, image_height = image.size
    current_ratio = image_width / image_height
    target_ratio = ASPECT_RATIO_PORTRAIT if current_ratio < 1 else ASPECT_RATIO_LANDSCAPE

    # Largest window of the target ratio that fits inside the image.
    if current_ratio > target_ratio:
        new_width, new_height = int(image_height * target_ratio), image_height
    else:
        new_width, new_height = image_width, int(image_width / target_ratio)

    if not faces:
        x = (image_width - new_width) // 2
        y = (image_height - new_height) // 2
        return image.crop((x, y, x + new_width, y + new_height))

    # Union of all face boxes, padded and clipped to the image.
    face_x1 = max(0, min(int(face['bbox'][0]) for face in faces) - padding)
    face_y1 = max(0, min(int(face['bbox'][1]) for face in faces) - padding)
    face_x2 = min(image_width, max(int(face['bbox'][2]) for face in faces) + padding)
    face_y2 = min(image_height, max(int(face['bbox'][3]) for face in faces) + padding)

    def _centered_start(lo, hi, window, limit):
        # Start of a `window`-long span centered on [lo, hi], clamped to [0, limit].
        # If hi - lo <= window, the clamped span still contains [lo, hi].
        start = (lo + hi - window) // 2
        return max(0, min(start, limit - window))

    x = _centered_start(face_x1, face_x2, new_width, image_width)
    y = _centered_start(face_y1, face_y2, new_height, image_height)
    return image.crop((x, y, x + new_width, y + new_height))
def run_flow(input_image, holiday, message):
    """Build a holiday-card background for *input_image* and suggest a text box.

    Steps:
    1. Detect faces and crop with crop_to_ratio_while_preventing_faces.
    2. Ask GPT for an edit prompt.
    3. Segment the foreground and dilate its mask.
    4. Repaint the background with modify_background.
    5. Locate a face-free rectangle for the message.

    Args:
        input_image (PIL.Image.Image): Uploaded photo.
        holiday (str): Holiday name used in the GPT prompt.
        message (str): Optional card message.
            NOTE(review): currently unused. No text is rendered yet; only the
            target rectangle is computed.

    Returns:
        tuple: ``(output_image, output_image_with_text_rectangle, text_x,
        text_y, text_width, text_height)``. The last four values are fractions
        of the image width/height.
    """
    faces = detect_faces(input_image)
    cropped_image = crop_to_ratio_while_preventing_faces(input_image, faces)
    # The face boxes above are in input_image coordinates; cropping shifts the
    # origin, so detect again on the crop before placing the text box.
    cropped_faces = detect_faces(cropped_image)

    prompt = image_to_prompt(cropped_image, holiday)
    print(prompt)

    _, _, mask = remove_background(cropped_image)
    output_image = modify_background(cropped_image, dilate_mask(mask), prompt)

    text_x_in_percent, text_y_in_percent, text_width_in_percent, text_height_in_percent = \
        find_place_to_add_text(cropped_image, cropped_faces)

    # Fractions map onto the edited image regardless of the resolution it came back at.
    text_x = text_x_in_percent * output_image.width
    text_y = text_y_in_percent * output_image.height
    text_width = text_width_in_percent * output_image.width
    text_height = text_height_in_percent * output_image.height

    # Draw the preview rectangle on a copy so output_image stays clean.
    output_image_with_text_rectangle = output_image.copy()
    draw = ImageDraw.Draw(output_image_with_text_rectangle)
    draw.rectangle((text_x, text_y, text_x + text_width, text_y + text_height), outline="red")

    return (
        output_image,
        output_image_with_text_rectangle,
        text_x_in_percent,
        text_y_in_percent,
        text_width_in_percent,
        text_height_in_percent,
    )
# Gradio UI: the three inputs map positionally onto
# run_flow(input_image, holiday, message) and the six outputs onto its return tuple.
demo = gr.Interface(
    fn=run_flow,
    inputs=[
        gr.Image(type="pil"),
        gr.Text(label="Holiday (e.g. Christmas, New Year's, etc.)"),
        gr.Text(label="Optional Message", placeholder="Enter your holiday message here..."),
    ],
    outputs=[
        gr.Image(type="pil", label="Output Image"),
        gr.Image(type="pil", label="Output Image With Text Rectangle"),
        gr.Number(label="Text Top Left X"),
        gr.Number(label="Text Top Left Y"),
        gr.Number(label="Text Width"),
        gr.Number(label="Text Height"),
    ],
    title="Holiday Card Generator",
    description="Upload an image to generate a holiday card",
)

# Launch only when run as a script (e.g. `python app.py`), so importing the
# module (tests, `gradio app.py` reload mode) does not start a server.
if __name__ == "__main__":
    demo.launch()