Spaces:
Sleeping
Sleeping
import torch | |
from diffusers import StableDiffusionPipeline | |
from PIL import Image, ImageDraw, ImageFont | |
import os | |
import gradio as gr # Import Gradio for the interface | |
# Function to generate image
# Loaded pipelines keyed by (model_id, device). Loading Stable Diffusion is slow
# and memory-hungry, so it happens once per process instead of once per request.
_PIPELINE_CACHE = {}


def _load_pipeline(model_id, api_key):
    """Return a StableDiffusionPipeline for *model_id*, loading it on first use."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    cache_key = (model_id, device)
    if cache_key not in _PIPELINE_CACHE:
        if device == "cuda":
            # Half precision from the fp16 weights: less GPU memory, faster load.
            precision_kwargs = {"torch_dtype": torch.float16, "revision": "fp16"}
        else:
            # Many CPU ops lack float16 kernels, so run the CPU fallback in float32.
            precision_kwargs = {"torch_dtype": torch.float32}
        # NOTE(review): newer diffusers releases deprecate `use_auth_token` and
        # `revision="fp16"` in favour of `token=` and `variant="fp16"` -- consider
        # migrating once the pinned diffusers version is confirmed.
        pipeline = StableDiffusionPipeline.from_pretrained(
            model_id,
            use_auth_token=api_key,
            **precision_kwargs,
        )
        _PIPELINE_CACHE[cache_key] = pipeline.to(device)
    return _PIPELINE_CACHE[cache_key]


def generate_image(prompt, height=512, width=512):
    """Generate a single image for *prompt* with Stable Diffusion v1-4.

    The pipeline is loaded on the first call and reused by later calls.

    Args:
        prompt: Text description passed to the diffusion pipeline.
        height: Output height in pixels.
        width: Output width in pixels.

    Returns:
        The first entry of the pipeline output's ``images`` list.

    Raises:
        ValueError: If the ``HUGGINGFACE_API_KEY`` environment variable is
            unset or empty.
    """
    model_id = "CompVis/stable-diffusion-v1-4"
    api_key = os.getenv("HUGGINGFACE_API_KEY")
    if not api_key:
        raise ValueError("Hugging Face API key is not set. Export it as HUGGINGFACE_API_KEY.")
    pipeline = _load_pipeline(model_id, api_key)
    return pipeline(prompt, height=height, width=width).images[0]
# Function to Add Text to Image
# Font files tried in order. "arial.ttf" is usually missing on Linux hosts, where
# "DejaVuSans.ttf" is commonly available instead.
_FONT_CANDIDATES = ("arial.ttf", "DejaVuSans.ttf")


def _load_font(size):
    """Return a font rendered at *size* px, falling back to Pillow's built-in font."""
    size = max(size, 1)  # TrueType fonts require a positive size.
    for font_name in _FONT_CANDIDATES:
        try:
            return ImageFont.truetype(font_name, size)
        except OSError:
            continue
    try:
        # Pillow >= 10.1 can scale its built-in font to any size.
        return ImageFont.load_default(size=size)
    except TypeError:
        # Older Pillow only offers a small fixed-size bitmap default font.
        return ImageFont.load_default()


def add_text_to_image(image, product_name, tagline, cta_text, font_size=50):
    """
    Add clean and sharp text to the generated image.

    Draws the product name (font_size + 20), tagline (font_size) and call to
    action (font_size - 10, gold) at fixed positions near the top-left corner.
    The image is modified in place and also returned.

    Args:
        image: Image to draw on (must be usable with ``ImageDraw.Draw``).
        product_name: Headline text drawn at (50, 50).
        tagline: Text drawn at (50, 150).
        cta_text: Call-to-action text drawn at (50, 250).
        font_size: Base font size in pixels; each derived size is clamped to >= 1.

    Returns:
        The same ``image`` object, with text drawn on it.
    """
    draw = ImageDraw.Draw(image)
    product_font = _load_font(font_size + 20)
    tagline_font = _load_font(font_size)
    cta_font = _load_font(font_size - 10)
    # Add product name, tagline, and CTA to the image
    draw.text((50, 50), product_name, font=product_font, fill="white")
    draw.text((50, 150), tagline, font=tagline_font, fill="white")
    draw.text((50, 250), cta_text, font=cta_font, fill="gold")
    return image
# Main function to generate advertisement
def _paste_overlay(base, path, size, position):
    """Paste the image file at *path* onto *base* as a square of side <= *size*.

    The overlay is shrunk so it fits inside *base* rather than being clipped,
    and its alpha channel is used as the paste mask. The file is closed
    before returning.
    """
    x, y = position
    side = min(size, base.width - x, base.height - y)
    if side <= 0:
        return  # No room on the canvas at this position.
    with Image.open(path) as src:
        overlay = src.convert("RGBA").resize((side, side))
    base.paste(overlay, (x, y), overlay)


def generate_advertisement(brand_title, tagline, cta, custom_prompt=None, brand_logo=None, product_image=None):
    """
    Generate advertisement image with text overlay and optional logo/product image.

    Args:
        brand_title: Brand/product name; used in the default prompt and drawn as headline.
        tagline: Tagline; used in the default prompt and drawn below the headline.
        cta: Call-to-action text drawn below the tagline.
        custom_prompt: Optional prompt replacing the default one; blank or
            whitespace-only values fall back to the default prompt.
        brand_logo: Optional path to a logo image, pasted at (50, 350), up to 150x150.
        product_image: Optional path to a product image, pasted at (250, 350),
            up to 300x300 and shrunk as needed to fit inside the canvas.

    Returns:
        The generated image with text and overlays applied.
    """
    if custom_prompt and custom_prompt.strip():
        prompt = custom_prompt
    else:
        prompt = (
            f"An elegant advertisement for {brand_title}, featuring gold and white tones, "
            f"with a radiant and premium look. Product focus and beautiful typography for '{tagline}'."
        )
    # Generate the base image using Stable Diffusion
    generated_image = generate_image(prompt)
    # Overlay text (brand title, tagline, and CTA)
    final_image = add_text_to_image(generated_image, brand_title, tagline, cta)
    # Optionally add logo and product images
    if brand_logo:
        _paste_overlay(final_image, brand_logo, 150, (50, 350))
    if product_image:
        _paste_overlay(final_image, product_image, 300, (250, 350))
    return final_image
# Gradio Interface
def _upload_path(upload):
    """Return a filesystem path for a Gradio upload, or None if nothing was uploaded.

    Older Gradio releases pass a tempfile-like object exposing ``.name``;
    newer ones pass the file path itself as a string.
    """
    if not upload:
        return None
    if isinstance(upload, (str, os.PathLike)):
        return upload
    return upload.name


def gradio_interface(brand_title, tagline, cta, custom_prompt, brand_logo, product_image):
    """
    Gradio interface wrapper to call the advertisement generation function.

    Normalizes the two optional file uploads to paths and forwards everything
    to ``generate_advertisement``.

    Returns:
        The generated advertisement image.
    """
    # Generate the ad
    return generate_advertisement(
        brand_title=brand_title,
        tagline=tagline,
        cta=cta,
        custom_prompt=custom_prompt,
        brand_logo=_upload_path(brand_logo),
        product_image=_upload_path(product_image),
    )
# Gradio UI Layout
interface = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Textbox(label="Brand Title", placeholder="e.g., GlowWell Skin Serum"),
        gr.Textbox(label="Tagline", placeholder="e.g., Radiance Redefined"),
        gr.Textbox(label="Call to Action (CTA)", placeholder="e.g., Shop Now"),
        gr.Textbox(label="Custom Prompt (Optional)", placeholder="Describe your ad style..."),
        # Only accept images: these files are opened with PIL downstream.
        gr.File(label="Brand Logo (Optional)", file_types=["image"]),
        gr.File(label="Product Image (Optional)", file_types=["image"]),
    ],
    outputs=gr.Image(type="pil", label="Generated Advertisement"),
    title="AI-Powered Advertisement Generator",
    description="Generate stunning advertisements using Stable Diffusion. Provide brand details, and optionally upload images or add custom descriptions to create your perfect ad.",
)

# Launch the interface
if __name__ == "__main__":
    # Image generation is slow; queuing lets long-running requests complete
    # instead of hitting request timeouts.
    interface.queue().launch()