import gradio as gr
from PIL import Image, ImageDraw
import torch
from torchvision import transforms
from transformers import AutoModelForImageSegmentation
from openai import OpenAI
import os
import base64
import io
import requests
import numpy as np
from scipy import ndimage
from insightface.app import FaceAnalysis
IDEOGRAM_API_KEY = os.getenv('IDEOGRAM_API_KEY')
IDEOGRAM_URL = "https://api.ideogram.ai/edit"
face_detection_app = FaceAnalysis(allowed_modules=['detection']) # enable detection model only
face_detection_app.prepare(ctx_id=0, det_size=(640, 640))
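# Note: in insightface, ctx_id=0 requests the first GPU (a negative ctx_id forces
# CPU), and det_size is the resolution the detector runs at. Each detected face
# exposes a .bbox attribute with (x1, y1, x2, y2) pixel coordinates, which the
# layout helpers below rely on.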
client = OpenAI(api_key=os.getenv('OPENAI_API_KEY'))
# GPT configuration constants
GPT_MODEL_NAME = "gpt-4o"
GPT_MAX_TOKENS = 500
model = AutoModelForImageSegmentation.from_pretrained('briaai/RMBG-2.0', trust_remote_code=True)
torch.set_float32_matmul_precision('high')
if torch.cuda.is_available():
    model = model.to('cuda')
model.eval()
GPT_PROMPT = '''
You are a background editor.
Your job is to adjust the background of the image to have a {{holiday}} vibe, but take into consideration the perspective and the logic of the image.
Your output should be a prompt that can be used to edit the background of the image.
The background should be edited in a way that is consistent with the image.
The prompt should not include any text or writing in the background.
'''
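# The {{holiday}} placeholder is filled in with plain string substitution before
# the prompt is sent, e.g.:
#   GPT_PROMPT.replace("{{holiday}}", "Christmas")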
def image_to_prompt(image: Image.Image, holiday: str) -> str:
    """Ask GPT to write a background-editing prompt for the given image and holiday."""
    base64_image = encode_image(image)
    messages = [{
        "role": "user",
        "content": [
            {"type": "text", "text": GPT_PROMPT.replace("{{holiday}}", holiday)},
            # encode_image produces PNG bytes, so declare the data URL as image/png
            {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{base64_image}"}}
        ]
    }]
    response = client.chat.completions.create(
        model=GPT_MODEL_NAME,
        messages=messages,
        max_tokens=GPT_MAX_TOKENS
    )
    return response.choices[0].message.content
def encode_image(image: Image.Image) -> str:
    """Convert a PIL Image to a base64-encoded PNG string.

    Args:
        image (PIL.Image.Image): The PIL Image to encode

    Returns:
        str: Base64-encoded image string
    """
    # Save the image as PNG to an in-memory buffer, then encode the bytes
    buffer = io.BytesIO()
    image.save(buffer, format='PNG')
    return base64.b64encode(buffer.getvalue()).decode('utf-8')
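# Quick sanity check (illustrative only):
#   img = Image.new("RGB", (4, 4), "red")
#   b64 = encode_image(img)
#   restored = Image.open(io.BytesIO(base64.b64decode(b64)))  # the same 4x4 PNG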
def remove_background(input_image):
    image_size = (1024, 1024)
    # Preprocess the input image for the segmentation model
    transform_image = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    input_tensor = transform_image(input_image).unsqueeze(0)
    if torch.cuda.is_available():
        input_tensor = input_tensor.to('cuda')
    # Generate the foreground prediction
    with torch.no_grad():
        preds = model(input_tensor)[-1].sigmoid().cpu()
    pred = preds[0].squeeze()
    pred_pil = transforms.ToPILImage()(pred)
    mask = pred_pil.resize(input_image.size)
    # Foreground-only image: subject opaque, background transparent
    result_image = input_image.copy()
    result_image.putalpha(mask)
    # Background-only image: invert the mask so the subject is transparent
    only_background_image = input_image.copy()
    inverted_mask = Image.eval(mask, lambda x: 255 - x)
    only_background_image.putalpha(inverted_mask)
    return result_image, only_background_image, mask
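# The returned mask is a single-channel image the same size as the input: white
# (255) marks the detected subject, black (0) the background. putalpha() therefore
# keeps the subject opaque in result_image and, with the inverted mask, keeps only
# the background in only_background_image.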
def modify_background(image: Image.Image, mask: Image.Image, prompt: str) -> Image.Image:
    # Serialize the image and mask to PNG bytes for the multipart upload
    image_buffer = io.BytesIO()
    image.save(image_buffer, format='PNG')
    image_bytes = image_buffer.getvalue()
    mask_buffer = io.BytesIO()
    mask.save(mask_buffer, format='PNG')
    mask_bytes = mask_buffer.getvalue()
    files = {
        "image_file": ("image.png", image_bytes, "image/png"),
        "mask": ("mask.png", mask_bytes, "image/png")
    }
    prevent_text_in_background = "Do not include any text or writing in the background."
    prompt = f"{prompt} {prevent_text_in_background}"
    payload = {
        "prompt": prompt,
        "model": "V_2",
        "magic_prompt_option": "ON",
        "num_images": 1,
        "style_type": "REALISTIC"
    }
    headers = {"Api-Key": IDEOGRAM_API_KEY}
    response = requests.post(IDEOGRAM_URL, data=payload, files=files, headers=headers)
    if response.status_code == 200:
        response_data = response.json()
        result_image_url = response_data.get('data')[0].get('url')
        if result_image_url:
            # Download the generated image and return it as a PIL Image
            result_response = requests.get(result_image_url)
            return Image.open(io.BytesIO(result_response.content))
    raise Exception(f"Failed to modify background: {response.text}")
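# The success branch above assumes Ideogram's edit endpoint responds with JSON of
# the shape {"data": [{"url": "<generated image url>"}, ...]}; if the API returns
# a different structure, the URL extraction will need to be adjusted.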
def dilate_mask(mask: Image.Image) -> Image.Image:
    # Convert the mask to a numpy array
    mask_array = np.array(mask)
    # Grayscale dilation: a maximum filter grows the white (foreground) region
    dilated_mask = ndimage.maximum_filter(mask_array, size=20)
    # Convert back to a PIL Image
    return Image.fromarray(dilated_mask.astype(np.uint8))
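# maximum_filter with size=20 replaces each pixel with the maximum over a 20x20
# window, so white regions grow by roughly 10px in every direction, giving the
# background edit some margin around the subject. Illustrative 1-D example:
#   ndimage.maximum_filter(np.array([0, 0, 255, 0, 0]), size=3)
#   -> array([  0, 255, 255, 255,   0])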
def detect_faces(image: Image.Image) -> list[dict]:
    # insightface expects a numpy array; note it conventionally takes OpenCV-style
    # BGR input, while PIL produces RGB. Detection generally tolerates the swap.
    image_np = np.array(image)
    faces = face_detection_app.get(image_np)
    return faces
def check_text_position(x, y, text_rect_width, text_rect_height, face_rects, image_width, image_height):
    # Compute the text rectangle bounds from its center point
    text_x1 = x - text_rect_width // 2
    text_y1 = y - text_rect_height // 2
    text_x2 = x + text_rect_width // 2
    text_y2 = y + text_rect_height // 2
    # Reject positions that fall outside the image
    if (text_x1 < 0 or text_x2 > image_width or
            text_y1 < 0 or text_y2 > image_height):
        return False
    # Reject positions that collide with any face rectangle
    for face_rect in face_rects:
        fx1, fy1, fx2, fy2 = face_rect
        if not (text_x2 < fx1 or text_x1 > fx2 or text_y2 < fy1 or text_y1 > fy2):
            return False
    return True
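# The collision test above is a standard axis-aligned rectangle check: two
# rectangles are disjoint iff one lies entirely left of, right of, above, or
# below the other. For example, text box (10, 10, 50, 30) and face box
# (40, 20, 80, 60) overlap because none of the four disjointness conditions hold.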
def find_place_to_add_text(image: Image.Image, faces: list[dict]) -> tuple[float, float, float, float]:
    image_width, image_height = image.size
    # Convert face coordinates to padded rectangles for collision detection
    face_rects = []
    padding = 20  # Padding around faces
    for face in faces:
        x1, y1, x2, y2 = map(int, face.bbox)
        face_rects.append((
            max(0, x1 - padding),
            max(0, y1 - padding),
            min(image_width, x2 + padding),
            min(image_height, y2 + padding)
        ))
    # Candidate text positions (center points)
    padding_x = int(0.1 * image_width)
    padding_y = int(0.1 * image_height)
    positions = [
        (image_width // 2, int(0.85 * image_height) - padding_y),   # Bottom center
        (image_width // 2, int(0.15 * image_height) + padding_y),   # Top center
        (int(0.15 * image_width) + padding_x, image_height // 2),   # Left middle
        (int(0.85 * image_width) - padding_x, image_height // 2)    # Right middle
    ]
    # Start with the largest desired text size and gradually reduce
    current_text_width = 0.8
    current_text_height = 0.3
    min_text_width = 0.1
    min_text_height = 0.03
    reduction_factor = 0.9  # Reduce size by 10% each iteration
    while current_text_width >= min_text_width and current_text_height >= min_text_height:
        text_rect_width = current_text_width * image_width
        text_rect_height = current_text_height * image_height
        # Try each position with the current size
        for x, y in positions:
            if check_text_position(x, y, text_rect_width, text_rect_height,
                                   face_rects, image_width, image_height):
                top_left_x_in_percent = (x - text_rect_width // 2) / image_width
                top_left_y_in_percent = (y - text_rect_height // 2) / image_height
                return top_left_x_in_percent, top_left_y_in_percent, current_text_width, current_text_height
        # If no position works, reduce the text size and try again
        current_text_width *= reduction_factor
        current_text_height *= reduction_factor
    # Fallback: bottom center at the minimum size
    print("Failed to find a suitable position")
    return (
        (image_width // 2 - (min_text_width * image_width) // 2) / image_width,   # x position
        (int(0.85 * image_height) - (min_text_height * image_height) // 2) / image_height,  # y position
        min_text_width,   # width
        min_text_height   # height
    )
def crop_to_ratio_while_preserving_faces(image: Image.Image, faces: list[dict]) -> Image.Image:
    ASPECT_RATIO_PORTRAIT = 5 / 7
    ASPECT_RATIO_LANDSCAPE = 7 / 5
    image_width, image_height = image.size
    # Pick the target ratio based on the current orientation
    current_ratio = image_width / image_height
    is_portrait = current_ratio < 1
    target_ratio = ASPECT_RATIO_PORTRAIT if is_portrait else ASPECT_RATIO_LANDSCAPE
    # Calculate the cropped dimensions
    if current_ratio > target_ratio:
        new_width = int(image_height * target_ratio)
        new_height = image_height
    else:
        new_width = image_width
        new_height = int(image_width / target_ratio)
    # With no faces, fall back to a center crop
    if not faces:
        x = (image_width - new_width) // 2
        y = (image_height - new_height) // 2
        return image.crop((x, y, x + new_width, y + new_height))
    # Find the bounding box that contains all faces
    face_x1 = min(int(face.bbox[0]) for face in faces)
    face_y1 = min(int(face.bbox[1]) for face in faces)
    face_x2 = max(int(face.bbox[2]) for face in faces)
    face_y2 = max(int(face.bbox[3]) for face in faces)
    # Add padding around the faces
    padding = 50
    face_x1 = max(0, face_x1 - padding)
    face_y1 = max(0, face_y1 - padding)
    face_x2 = min(image_width, face_x2 + padding)
    face_y2 = min(image_height, face_y2 + padding)
    # Calculate crop coordinates that keep the faces inside the window
    x = max(0, min(face_x1, image_width - new_width))
    y = max(0, min(face_y1, image_height - new_height))
    # Shift the window if faces would still be cut off on the right or bottom
    if x + new_width < face_x2:
        x = max(0, face_x2 - new_width)
    if y + new_height < face_y2:
        y = max(0, face_y2 - new_height)
    return image.crop((x, y, x + new_width, y + new_height))
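# Worked example (illustrative): a 1000x800 image has ratio 1.25, which is
# landscape, so the target ratio is 7/5 = 1.4. Since 1.25 < 1.4 the width is kept
# and the height shrinks: new_width = 1000, new_height = int(1000 / 1.4) = 714.
# The crop window is then shifted, if needed, so the padded face box stays inside it.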
def run_flow(input_image, holiday, message):
    # NOTE: `message` is collected by the UI but not used in this flow yet.
    faces = detect_faces(input_image)
    cropped_image = crop_to_ratio_while_preserving_faces(input_image, faces)
    prompt = image_to_prompt(cropped_image, holiday)
    print(prompt)
    # Only the foreground mask is needed here
    _, _, mask = remove_background(cropped_image)
    dilated_mask = dilate_mask(mask)
    output_image = modify_background(cropped_image, dilated_mask, prompt)
    # Draw the suggested text rectangle on a copy of the edited image
    output_image_with_text_rectangle = output_image.copy()
    text_x_in_percent, text_y_in_percent, text_width_in_percent, text_height_in_percent = find_place_to_add_text(cropped_image, faces)
    text_x = text_x_in_percent * output_image.width
    text_y = text_y_in_percent * output_image.height
    text_width = text_width_in_percent * output_image.width
    text_height = text_height_in_percent * output_image.height
    draw = ImageDraw.Draw(output_image_with_text_rectangle)
    draw.rectangle((text_x, text_y, text_x + text_width, text_y + text_height), outline="red")
    return output_image, output_image_with_text_rectangle, text_x_in_percent, text_y_in_percent, text_width_in_percent, text_height_in_percent
# Gradio interface
demo = gr.Interface(
    fn=run_flow,
    inputs=[
        gr.Image(type="pil"),
        gr.Text(label="Holiday (e.g. Christmas, New Year's, etc.)"),
        gr.Text(label="Optional Message", placeholder="Enter your holiday message here...")
    ],
    outputs=[
        gr.Image(type="pil", label="Output Image"),
        gr.Image(type="pil", label="Output Image With Text Rectangle"),
        gr.Number(label="Text Top Left X"),
        gr.Number(label="Text Top Left Y"),
        gr.Number(label="Text Width"),
        gr.Number(label="Text Height")
    ],
    title="Holiday Card Generator",
    description="Upload an image to generate a holiday card"
)
demo.launch()
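# To run locally (assuming this file is saved as app.py), set the two required
# API keys and start the script:
#   export OPENAI_API_KEY=sk-...
#   export IDEOGRAM_API_KEY=...
#   python app.py
# Gradio then serves the UI on http://127.0.0.1:7860 by default.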