# holiday_cards/app.py
import gradio as gr
from PIL import Image, ImageDraw
import torch
from torchvision import transforms
from transformers import AutoModelForImageSegmentation
from openai import OpenAI
import os
import base64
import io
import requests
import numpy as np
from scipy import ndimage
from insightface.app import FaceAnalysis
IDEOGRAM_API_KEY = os.getenv('IDEOGRAM_API_KEY')
IDEOGRAM_URL = "https://api.ideogram.ai/edit"
face_detection_app = FaceAnalysis(allowed_modules=['detection']) # enable detection model only
face_detection_app.prepare(ctx_id=0, det_size=(640, 640))
client = OpenAI(api_key=os.getenv('OPENAI_API_KEY'))
# OpenAI model configuration
GPT_MODEL_NAME = "gpt-4o"
GPT_MAX_TOKENS = 500
model = AutoModelForImageSegmentation.from_pretrained('briaai/RMBG-2.0', trust_remote_code=True)
torch.set_float32_matmul_precision('high')
if torch.cuda.is_available():
model = model.to('cuda')
model.eval()
GPT_PROMPT = '''
You are a background editor.
Your job is to adjust the background of the image to have a {{holiday}} vibe, while taking into consideration the perspective and the logic of the image.
Your output should be a prompt that can be used to edit the background of the image.
The background should be edited in a way that is consistent with the image.
The prompt should not include any text or writing in the background.
'''
def image_to_prompt(image: Image.Image, holiday: str) -> str:
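    """Ask GPT to describe a holiday-themed background edit for the image; returns the prompt text."""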
base64_image = encode_image(image)
messages = [{
"role": "user",
"content": [
{"type": "text", "text": GPT_PROMPT.replace("{{holiday}}", holiday)},
{"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}}
]
}]
response = client.chat.completions.create(
model=GPT_MODEL_NAME,
messages=messages,
max_tokens=GPT_MAX_TOKENS
)
full_response = response.choices[0].message.content
return full_response
def encode_image(image: Image.Image) -> str:
"""Convert a PIL Image to base64 encoded string.
Args:
image (PIL.Image.Image): The PIL Image to encode
Returns:
str: Base64 encoded image string
"""
# Create a temporary buffer to save the image
buffer = io.BytesIO()
# Save the image as PNG to the buffer
image.save(buffer, format='PNG')
# Get the bytes from the buffer and encode to base64
return base64.b64encode(buffer.getvalue()).decode('utf-8')
def remove_background(input_image: Image.Image) -> tuple[Image.Image, Image.Image, Image.Image]:
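    """Segment the subject with RMBG-2.0; returns (foreground, background-only, mask) images."""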
image_size = (1024, 1024)
# Transform the input image
transform_image = transforms.Compose([
transforms.Resize(image_size),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
# Process the image
input_tensor = transform_image(input_image).unsqueeze(0)
if torch.cuda.is_available():
input_tensor = input_tensor.to('cuda')
# Generate prediction
with torch.no_grad():
preds = model(input_tensor)[-1].sigmoid().cpu()
pred = preds[0].squeeze()
pred_pil = transforms.ToPILImage()(pred)
mask = pred_pil.resize(input_image.size)
# Create image without background
result_image = input_image.copy()
result_image.putalpha(mask)
# Create image with only background
only_background_image = input_image.copy()
inverted_mask = Image.eval(mask, lambda x: 255 - x) # Invert the mask
only_background_image.putalpha(inverted_mask)
return result_image, only_background_image, mask
def modify_background(image: Image.Image, mask: Image.Image, prompt: str) -> Image.Image:
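    """Send the image, mask, and prompt to the Ideogram edit endpoint and return the edited image."""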
# Convert PIL images to bytes
image_buffer = io.BytesIO()
image.save(image_buffer, format='PNG')
image_bytes = image_buffer.getvalue()
mask_buffer = io.BytesIO()
mask.save(mask_buffer, format='PNG')
mask_bytes = mask_buffer.getvalue()
# Create the files dictionary with actual bytes data
files = {
"image_file": ("image.png", image_bytes, "image/png"),
"mask": ("mask.png", mask_bytes, "image/png") # You might want to send a different mask file
}
prevent_text_in_background = "Do not include any text or writing in the background."
prompt = f"{prompt} {prevent_text_in_background}"
payload = {
"prompt": prompt, # Use the actual prompt parameter
"model": "V_2",
"magic_prompt_option": "ON",
"num_images": 1,
"style_type": "REALISTIC"
}
headers = {"Api-Key": IDEOGRAM_API_KEY}
response = requests.post(IDEOGRAM_URL, data=payload, files=files, headers=headers)
    if response.status_code == 200:
        response_data = response.json()
        # The edit endpoint responds with a 'data' list of generated images; take the first URL
        data = response_data.get('data') or []
        result_image_url = data[0].get('url') if data else None
        if result_image_url:
            result_response = requests.get(result_image_url)
            return Image.open(io.BytesIO(result_response.content))
    raise Exception(f"Failed to modify background: {response.text}")
def dilate_mask(mask: Image.Image) -> Image.Image:
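    """Dilate the mask with a 20-pixel maximum filter to leave a safety margin around the subject."""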
# Convert mask to numpy array
mask_array = np.array(mask)
# Apply maximum filter using scipy.ndimage
dilated_mask = ndimage.maximum_filter(mask_array, size=20)
# Convert back to PIL Image
return Image.fromarray(dilated_mask.astype(np.uint8))
def detect_faces(image: Image.Image) -> list[dict]:
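    """Run insightface detection on the image and return the detected faces."""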
# Convert PIL Image to numpy array
image_np = np.array(image)
faces = face_detection_app.get(image_np)
return faces
def check_text_position(x, y, text_rect_width, text_rect_height, face_rects, image_width, image_height):
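    """Return True if a text box of the given size centered at (x, y) stays inside the image and overlaps no face rectangle."""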
# Calculate text rectangle bounds
text_x1 = x - text_rect_width//2
text_y1 = y - text_rect_height//2
text_x2 = x + text_rect_width//2
text_y2 = y + text_rect_height//2
# Check if text is within image bounds
if (text_x1 < 0 or text_x2 > image_width or
text_y1 < 0 or text_y2 > image_height):
return False
# Check for collision with any face
for face_rect in face_rects:
fx1, fy1, fx2, fy2 = face_rect
# Check if rectangles overlap
if not (text_x2 < fx1 or text_x1 > fx2 or text_y2 < fy1 or text_y1 > fy2):
return False
return True
def find_place_to_add_text(image: Image.Image, faces: list[dict]) -> tuple[float, float, float, float]:
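    """Search for a text box placement that avoids faces, shrinking the box until one fits.

    Returns (top_left_x, top_left_y, width, height), all as fractions of the image size.
    """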
image_width, image_height = image.size
# Convert face coordinates to rectangles for collision detection
face_rects = []
padding = 20 # Padding around faces
for face in faces:
bbox = face.bbox # Get bounding box coordinates
x1, y1, x2, y2 = map(int, bbox)
face_rects.append((
max(0, x1-padding),
max(0, y1-padding),
min(image_width, x2+padding),
min(image_height, y2+padding)
))
# Define possible text positions
padding_x = int(0.1 * image_width)
padding_y = int(0.1 * image_height)
positions = [
(image_width//2, int(0.85*image_height) - padding_y), # Bottom center
(image_width//2, int(0.15*image_height) + padding_y), # Top center
(int(0.15*image_width) + padding_x, image_height//2), # Left middle
(int(0.85*image_width) - padding_x, image_height//2) # Right middle
]
# Start with largest desired text size and gradually reduce
current_text_width = 0.8
current_text_height = 0.3
min_text_width = 0.1
min_text_height = 0.03
reduction_factor = 0.9 # Reduce size by 10% each iteration
while current_text_width >= min_text_width and current_text_height >= min_text_height:
text_rect_width = current_text_width * image_width
text_rect_height = current_text_height * image_height
# Try each position with current size
for x, y in positions:
if check_text_position(x, y, text_rect_width, text_rect_height,
face_rects, image_width, image_height):
top_left_x_in_percent = (x - text_rect_width//2) / image_width
top_left_y_in_percent = (y - text_rect_height//2) / image_height
return top_left_x_in_percent, top_left_y_in_percent, current_text_width, current_text_height
# If no position works, reduce text size and try again
current_text_width *= reduction_factor
current_text_height *= reduction_factor
    # No position worked even at the minimum size; fall back to bottom center
    print("Failed to find a suitable position")
return (
(image_width//2 - (min_text_width * image_width)//2) / image_width, # x position
(int(0.85*image_height) - (min_text_height * image_height)//2) / image_height, # y position
min_text_width, # width
min_text_height # height
)
def crop_to_ratio_while_preventing_faces(image: Image.Image, faces: list[dict]) -> Image.Image:
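    """Crop to a 5:7 (portrait) or 7:5 (landscape) aspect ratio while keeping all detected faces in frame."""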
ASPECT_RATIO_PORTRAIT = 5/7
ASPECT_RATIO_LANDSCAPE = 7/5
image_width, image_height = image.size
# Calculate current aspect ratio
current_ratio = image_width / image_height
is_portrait = current_ratio < 1
target_ratio = ASPECT_RATIO_PORTRAIT if is_portrait else ASPECT_RATIO_LANDSCAPE
# Calculate new dimensions
if current_ratio > target_ratio:
new_width = int(image_height * target_ratio)
new_height = image_height
else:
new_width = image_width
new_height = int(image_width / target_ratio)
# If no faces, just do center crop
if not faces:
x = (image_width - new_width) // 2
y = (image_height - new_height) // 2
return image.crop((x, y, x + new_width, y + new_height))
# Find the bounding box that contains all faces
    face_x1 = min(int(face.bbox[0]) for face in faces)
    face_y1 = min(int(face.bbox[1]) for face in faces)
    face_x2 = max(int(face.bbox[2]) for face in faces)
    face_y2 = max(int(face.bbox[3]) for face in faces)
# Add padding around faces
padding = 50
face_x1 = max(0, face_x1 - padding)
face_y1 = max(0, face_y1 - padding)
face_x2 = min(image_width, face_x2 + padding)
face_y2 = min(image_height, face_y2 + padding)
# Calculate crop coordinates that ensure faces are included
x = max(0, min(face_x1, image_width - new_width))
y = max(0, min(face_y1, image_height - new_height))
# Adjust if faces would be cut off
if x + new_width < face_x2:
x = max(0, face_x2 - new_width)
if y + new_height < face_y2:
y = max(0, face_y2 - new_height)
return image.crop((x, y, x + new_width, y + new_height))
def run_flow(input_image, holiday, message):
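    """Full pipeline: detect faces, crop around them, restyle the background, and pick a text placement.

    Note: `message` is collected by the UI but is not used here yet.
    """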
faces = detect_faces(input_image)
cropped_image = crop_to_ratio_while_preventing_faces(input_image, faces)
prompt = image_to_prompt(cropped_image, holiday)
print(prompt)
result_image, only_background_image, mask = remove_background(cropped_image)
dilated_mask = dilate_mask(mask)
output_image = modify_background(cropped_image, dilated_mask, prompt)
# Create a copy of the modified image before drawing
output_image_with_text_rectangle = output_image.copy()
text_x_in_percent, text_y_in_percent, text_width_in_percent, text_height_in_percent = find_place_to_add_text(cropped_image, faces)
text_x = text_x_in_percent * output_image.width
text_y = text_y_in_percent * output_image.height
text_width = text_width_in_percent * output_image.width
text_height = text_height_in_percent * output_image.height
draw = ImageDraw.Draw(output_image_with_text_rectangle)
draw.rectangle((text_x, text_y, text_x + text_width, text_y + text_height), outline="red")
return output_image, output_image_with_text_rectangle, text_x_in_percent, text_y_in_percent, text_width_in_percent, text_height_in_percent
# Gradio interface
demo = gr.Interface(
fn=run_flow,
inputs=[
gr.Image(type="pil"),
gr.Text(label="Holiday (e.g. Christmas, New Year's, etc.)"),
gr.Text(label="Optional Message", placeholder="Enter your holiday message here...")
],
outputs=[
gr.Image(type="pil", label="Output Image"),
gr.Image(type="pil", label="Output Image With Text Rectangle"),
gr.Number(label="Text Top Left X"),
gr.Number(label="Text Top Left Y"),
gr.Number(label="Text Width"),
gr.Number(label="Text Height")
],
title="Holiday Card Generator",
description="Upload an image to generate a holiday card"
)
if __name__ == "__main__":
    demo.launch()