holiday_cards / app.py
Amit Gazal
wip
666c963
raw
history blame
6.26 kB
import gradio as gr
from PIL import Image
import matplotlib.pyplot as plt
import torch
from torchvision import transforms
from transformers import AutoModelForImageSegmentation
from openai import OpenAI
import os
import base64
import io
import requests
import numpy as np
from scipy import ndimage
# --- External services -------------------------------------------------------
# Credentials are read from the environment; os.getenv yields None when unset.
IDEOGRAM_URL = "https://api.ideogram.ai/edit"
IDEOGRAM_API_KEY = os.getenv('IDEOGRAM_API_KEY')
client = OpenAI(api_key=os.getenv('OPENAI_API_KEY'))

# --- Chat model settings used when generating the editing prompt -------------
GPT_MODEL_NAME = "gpt-4o"
GPT_MAX_TOKENS = 500

# --- Segmentation model, loaded once at import time --------------------------
model = AutoModelForImageSegmentation.from_pretrained(
    'briaai/RMBG-2.0',
    trust_remote_code=True,
)
torch.set_float32_matmul_precision('high')
if torch.cuda.is_available():
    model = model.to('cuda')
# NOTE(review): placed outside the CUDA branch; the source's indentation was
# lost, so confirm this matches the original layout.
model.eval()
# Instruction text sent to the chat model alongside the uploaded image (it is
# the "text" part of the message built in image_to_prompt). The string is
# passed verbatim, so any edit here changes what the model is asked to do.
GPT_PROMPT = '''
I work with a tool that knows how to edit backgrounds.
I want your help with prompt.
I want to adjust their background to be in a christmas vibes.
For example, if you see a tree there, cover it in snow,
add christmas lights to some of the stuff in the background, maybe add a few elements like christmas tree, but take into considration the perspective and the logic of the image.
'''
def image_to_prompt(image: Image.Image) -> str:
    """Ask the chat model for a background-editing prompt for *image*.

    The image is base64-encoded and sent together with GPT_PROMPT as a single
    user message.

    Args:
        image: The picture to send, as a PIL image.

    Returns:
        The text content of the first choice in the model's response.

    Raises:
        RuntimeError: If the response carries no text content.
    """
    base64_image = encode_image(image)
    messages = [{
        "role": "user",
        "content": [
            {"type": "text", "text": GPT_PROMPT},
            # encode_image serializes as PNG, so the data URL must declare
            # image/png rather than image/jpeg.
            {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{base64_image}"}},
        ],
    }]
    response = client.chat.completions.create(
        model=GPT_MODEL_NAME,
        messages=messages,
        max_tokens=GPT_MAX_TOKENS,
    )
    full_response = response.choices[0].message.content
    if not full_response:
        # Fail here with a clear message instead of passing an empty prompt
        # on to the image-editing request.
        raise RuntimeError("Chat model returned no prompt text")
    return full_response
def encode_image(image: Image.Image) -> str:
    """Return *image* serialized as PNG and encoded as a base64 string.

    Args:
        image (PIL.Image.Image): The PIL Image to encode
    Returns:
        str: Base64 text (UTF-8 decoded) of the PNG bytes
    """
    # Write the PNG into memory rather than to disk.
    with io.BytesIO() as png_stream:
        image.save(png_stream, format='PNG')
        png_payload = png_stream.getvalue()
    encoded = base64.b64encode(png_payload)
    return encoded.decode('utf-8')
def remove_background(input_image):
    """Split *input_image* into foreground and background using the model's mask.

    Args:
        input_image: PIL image to segment. Any mode is accepted; an RGB view
            is used for inference only.

    Returns:
        A tuple ``(result_image, only_background_image, mask)`` where
        ``result_image`` is a copy of the input with the predicted mask as its
        alpha channel, ``only_background_image`` is a copy with the inverted
        mask as alpha, and ``mask`` is the prediction resized to the input's
        size.
    """
    image_size = (1024, 1024)
    # Transform the input image
    transform_image = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    # Normalize is configured with three-channel statistics, so RGBA,
    # grayscale or palette inputs must be converted before the transform.
    rgb_image = input_image.convert('RGB')
    input_tensor = transform_image(rgb_image).unsqueeze(0)
    if torch.cuda.is_available():
        input_tensor = input_tensor.to('cuda')
    # Generate prediction (last element of the model output, through sigmoid)
    with torch.no_grad():
        preds = model(input_tensor)[-1].sigmoid().cpu()
    pred = preds[0].squeeze()
    pred_pil = transforms.ToPILImage()(pred)
    # Bring the mask back to the original resolution
    mask = pred_pil.resize(input_image.size)
    # Create image without background
    result_image = input_image.copy()
    result_image.putalpha(mask)
    # Create image with only background
    only_background_image = input_image.copy()
    inverted_mask = Image.eval(mask, lambda x: 255 - x)  # Invert the mask
    only_background_image.putalpha(inverted_mask)
    return result_image, only_background_image, mask
def modify_background(
    image: Image.Image,
    mask: Image.Image,
    prompt: str,
    *,
    timeout: float = 120.0,
) -> Image.Image:
    """Send *image* and *mask* to the Ideogram edit endpoint and return the result.

    Args:
        image: Source picture; uploaded as PNG.
        mask: Mask image; uploaded as PNG.
        prompt: Text prompt forwarded to the edit request.
        timeout: Seconds allowed for each HTTP request (the edit call and the
            result download).

    Returns:
        The edited image, opened from the URL returned by the API.

    Raises:
        RuntimeError: If the API does not answer with HTTP 200, the response
            has no result URL, or the result download fails.
        requests.RequestException: On network errors or timeouts.
    """
    def _png_bytes(picture: Image.Image) -> bytes:
        # Serialize a PIL image to PNG bytes in memory.
        buffer = io.BytesIO()
        picture.save(buffer, format='PNG')
        return buffer.getvalue()

    files = {
        "image_file": ("image.png", _png_bytes(image), "image/png"),
        "mask": ("mask.png", _png_bytes(mask), "image/png"),
    }
    payload = {
        "prompt": prompt,
        "model": "V_2",
        "magic_prompt_option": "ON",
        "num_images": 1,
        "style_type": "REALISTIC",
    }
    headers = {"Api-Key": IDEOGRAM_API_KEY}
    # Without a timeout a stalled connection would block the request forever.
    response = requests.post(
        IDEOGRAM_URL,
        data=payload,
        files=files,
        headers=headers,
        timeout=timeout,
    )
    if response.status_code != 200:
        raise RuntimeError(f"Failed to modify background: {response.text}")

    # Expected shape: {"data": [{"url": ...}, ...]}. Any deviation is reported
    # as the same failure rather than leaking a TypeError/IndexError.
    try:
        result_image_url = response.json()["data"][0]["url"]
    except (ValueError, KeyError, IndexError, TypeError):
        result_image_url = None
    if not result_image_url:
        raise RuntimeError(f"Failed to modify background: {response.text}")

    result_response = requests.get(result_image_url, timeout=timeout)
    if result_response.status_code != 200:
        # Avoid handing an error page to the image decoder.
        raise RuntimeError(
            f"Failed to modify background: result download returned HTTP {result_response.status_code}"
        )
    return Image.open(io.BytesIO(result_response.content))
def dilate_mask(mask: Image.Image, size: int = 20) -> Image.Image:
    """Grow the bright regions of *mask* with a maximum filter.

    Args:
        mask: Mask image to dilate.
        size: Side length of the maximum-filter window. Larger values expand
            bright regions further. Defaults to 20.

    Returns:
        A new PIL image built from the filtered array cast to uint8.

    Raises:
        ValueError: If *size* is smaller than 1.
    """
    if size < 1:
        raise ValueError(f"size must be >= 1, got {size}")
    # Convert mask to numpy array
    mask_array = np.array(mask)
    # Each output pixel becomes the maximum over its size x size neighbourhood
    dilated_mask = ndimage.maximum_filter(mask_array, size=size)
    # Convert back to PIL Image
    return Image.fromarray(dilated_mask.astype(np.uint8))
def run_flow(input_image, holiday, message):
    """Run the full pipeline for one uploaded image.

    Steps: generate an editing prompt from the image, segment the image,
    dilate the resulting mask, then request the background edit.

    Args:
        input_image: The uploaded PIL image.
        holiday: Holiday name from the UI. Not used by this function.
        message: Optional message from the UI. Not used by this function.

    Returns:
        A tuple ``(mask, dilated_mask, modified_image)``.
    """
    background_prompt = image_to_prompt(input_image)
    print(background_prompt)
    # Only the mask is needed here; the two cut-out images are discarded.
    _, _, subject_mask = remove_background(input_image)
    widened_mask = dilate_mask(subject_mask)
    edited_image = modify_background(input_image, widened_mask, background_prompt)
    return subject_mask, widened_mask, edited_image
# Gradio UI: one image and two text fields in, three images out
# (matching the three values returned by run_flow).
_flow_inputs = [
    gr.Image(type="pil"),
    gr.Text(label="Holiday (e.g. Christmas, New Year's, etc.)"),
    gr.Text(label="Optional Message", placeholder="Enter your holiday message here..."),
]
_flow_outputs = [
    gr.Image(type="pil", label=output_label)
    for output_label in ("First Output", "Second Output", "Third Output")
]
demo = gr.Interface(
    fn=run_flow,
    inputs=_flow_inputs,
    outputs=_flow_outputs,
    title="Holiday Card Generator",
    description="Upload an image to generate a holiday card",
)
demo.launch()