import base64
import io
import os

import gradio as gr
import numpy as np
import requests
import torch
from openai import OpenAI
from PIL import Image
from scipy import ndimage
from torchvision import transforms
from transformers import AutoModelForImageSegmentation

IDEOGRAM_API_KEY = os.getenv('IDEOGRAM_API_KEY')
IDEOGRAM_URL = "https://api.ideogram.ai/edit"

client = OpenAI(api_key=os.getenv('OPENAI_API_KEY'))

GPT_MODEL_NAME = "gpt-4o"
GPT_MAX_TOKENS = 500

# Load the RMBG-2.0 background-removal model (uses custom code from the model repo)
model = AutoModelForImageSegmentation.from_pretrained('briaai/RMBG-2.0', trust_remote_code=True)
torch.set_float32_matmul_precision('high')
if torch.cuda.is_available():
    model = model.to('cuda')
model.eval()

GPT_PROMPT = '''
I work with a tool that knows how to edit image backgrounds.
I want your help writing a prompt for it.
I want to adjust the background of this image to have a Christmas vibe.
For example, if you see a tree, cover it in snow,
add Christmas lights to some of the objects in the background, and maybe add a few elements like a Christmas tree,
but take into consideration the perspective and the logic of the image.
'''

def image_to_prompt(image: Image.Image) -> str:
    """Ask GPT to write a background-editing prompt for the given image."""
    base64_image = encode_image(image)
    messages = [{
        "role": "user",
        "content": [
            {"type": "text", "text": GPT_PROMPT},
            {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{base64_image}"}}
        ]
    }]
    response = client.chat.completions.create(
        model=GPT_MODEL_NAME,
        messages=messages,
        max_tokens=GPT_MAX_TOKENS
    )
    return response.choices[0].message.content

def encode_image(image: Image.Image) -> str:
    """Convert a PIL Image to a base64-encoded PNG string.

    Args:
        image (PIL.Image.Image): The PIL Image to encode

    Returns:
        str: Base64-encoded image string
    """
    # Save the image as PNG to an in-memory buffer, then base64-encode the bytes
    buffer = io.BytesIO()
    image.save(buffer, format='PNG')
    return base64.b64encode(buffer.getvalue()).decode('utf-8')

def remove_background(input_image: Image.Image):
    """Segment the foreground and return (foreground, background-only, mask) images."""
    image_size = (1024, 1024)
    # Resize and normalize the input image for the segmentation model
    transform_image = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    input_tensor = transform_image(input_image).unsqueeze(0)
    if torch.cuda.is_available():
        input_tensor = input_tensor.to('cuda')
    # Predict the foreground mask
    with torch.no_grad():
        preds = model(input_tensor)[-1].sigmoid().cpu()
    pred = preds[0].squeeze()
    pred_pil = transforms.ToPILImage()(pred)
    mask = pred_pil.resize(input_image.size)
    # Foreground only: original image with the mask as its alpha channel
    result_image = input_image.copy()
    result_image.putalpha(mask)
    # Background only: original image with the inverted mask as its alpha channel
    only_background_image = input_image.copy()
    inverted_mask = Image.eval(mask, lambda x: 255 - x)
    only_background_image.putalpha(inverted_mask)
    return result_image, only_background_image, mask

def modify_background(image: Image.Image, mask: Image.Image, prompt: str) -> Image.Image:
    """Send the image and mask to the Ideogram edit endpoint and return the edited image."""
    # Serialize the PIL images to PNG bytes for the multipart upload
    image_buffer = io.BytesIO()
    image.save(image_buffer, format='PNG')
    image_bytes = image_buffer.getvalue()

    mask_buffer = io.BytesIO()
    mask.save(mask_buffer, format='PNG')
    mask_bytes = mask_buffer.getvalue()

    files = {
        "image_file": ("image.png", image_bytes, "image/png"),
        "mask": ("mask.png", mask_bytes, "image/png")
    }
    payload = {
        "prompt": prompt,
        "model": "V_2",
        "magic_prompt_option": "ON",
        "num_images": 1,
        "style_type": "REALISTIC"
    }
    headers = {"Api-Key": IDEOGRAM_API_KEY}

    response = requests.post(IDEOGRAM_URL, data=payload, files=files, headers=headers)
    if response.status_code == 200:
        # The edit endpoint is expected to return image URLs under 'data';
        # adjust if the actual API response format differs
        response_data = response.json()
        result_image_url = response_data.get('data')[0].get('url')
        if result_image_url:
            result_response = requests.get(result_image_url)
            return Image.open(io.BytesIO(result_response.content))
    raise Exception(f"Failed to modify background: {response.text}")

def dilate_mask(mask: Image.Image) -> Image.Image:
    """Expand the mask edges so the edit region fully covers the background."""
    mask_array = np.array(mask)
    # A 20x20 maximum filter acts as a grayscale dilation
    dilated_mask = ndimage.maximum_filter(mask_array, size=20)
    return Image.fromarray(dilated_mask.astype(np.uint8))

def run_flow(input_image, holiday, message):
    # NOTE: the holiday and message inputs are collected by the UI but not yet used
    # in the generation flow; the GPT prompt is currently Christmas-specific.
    prompt = image_to_prompt(input_image)
    print(prompt)
    _, _, mask = remove_background(input_image)
    dilated_mask = dilate_mask(mask)
    modified_image = modify_background(input_image, dilated_mask, prompt)
    return mask, dilated_mask, modified_image

# Gradio interface
demo = gr.Interface(
    fn=run_flow,
    inputs=[
        gr.Image(type="pil"),
        gr.Text(label="Holiday (e.g. Christmas, New Year's, etc.)"),
        gr.Text(label="Optional Message", placeholder="Enter your holiday message here...")
    ],
    outputs=[
        gr.Image(type="pil", label="Foreground Mask"),
        gr.Image(type="pil", label="Dilated Mask"),
        gr.Image(type="pil", label="Modified Background")
    ],
    title="Holiday Card Generator",
    description="Upload an image to generate a holiday card"
)

demo.launch()