snackshell's picture
Update app.py
a346948 verified
raw
history blame
2.58 kB
import os
import gradio as gr
from PIL import Image, ImageDraw, ImageFont
import io
import torch
from diffusers import DiffusionPipeline
# ===== CONFIGURATION =====
MODEL_NAME = "HiDream-ai/HiDream-I1-Full"  # Hugging Face Hub repo id of the pipeline
WATERMARK_TEXT = "SelamGPT"  # text stamped onto every generated image

# Run on the GPU when PyTorch can see one; half precision is only used there,
# the CPU path keeps full float32 weights.
# NOTE(review): HiDream's reference code loads this model in torch.bfloat16;
# float16 may overflow in large transformer-based pipelines — confirm outputs.
if torch.cuda.is_available():
    DEVICE = "cuda"
    TORCH_DTYPE = torch.float16
else:
    DEVICE = "cpu"
    TORCH_DTYPE = torch.float32
# ===== MODEL LOADING =====
# Process-wide cache for the pipeline: filled by the first load_model() call
# and returned unchanged by every later call.
pipe = None


def load_model():
    """Return the shared diffusion pipeline, loading it on first use.

    The first call loads ``MODEL_NAME`` with ``TORCH_DTYPE`` weights, moves it
    to ``DEVICE`` and, on CUDA, tries to enable xformers attention. Later calls
    return the pipeline cached in the module-level ``pipe``.

    Returns:
        The cached ``DiffusionPipeline`` instance.

    Raises:
        Any error from ``DiffusionPipeline.from_pretrained`` or ``.to``; in
        that case ``pipe`` stays ``None`` so the next call retries the load.
    """
    global pipe
    if pipe is not None:
        return pipe

    # NOTE(review): the HiDream-I1 model card loads this checkpoint with an
    # extra Llama text encoder (tokenizer_4 / text_encoder_4) — confirm that
    # from_pretrained succeeds without passing them.
    pipe = DiffusionPipeline.from_pretrained(
        MODEL_NAME,
        torch_dtype=TORCH_DTYPE,
    ).to(DEVICE)

    if DEVICE == "cuda":
        try:
            pipe.enable_xformers_memory_efficient_attention()
        except Exception:  # xformers not installed or unsupported by this model
            print("Xformers not available, using default attention")
            # Attention slicing is only a fallback for when no fused attention
            # kernel exists: it replaces the attention processor, and diffusers
            # warns that combining it with xformers or PyTorch 2's
            # scaled_dot_product_attention can slow inference considerably.
            if not hasattr(torch.nn.functional, "scaled_dot_product_attention"):
                pipe.enable_attention_slicing()

    return pipe
# ===== WATERMARK FUNCTION =====
def add_watermark(image):
    """Stamp ``WATERMARK_TEXT`` into the bottom-right corner of ``image``.

    The text is drawn in white over a half-transparent black shadow offset by
    one pixel. The image is then encoded as an optimized PNG in memory and
    decoded again, so the returned image reports ``format == "PNG"``.

    Args:
        image: PIL image to watermark. It is drawn on in place.

    Returns:
        The re-decoded PNG image, or ``image`` itself if any step fails (the
        error is printed).
    """
    font_size = 24
    margin = 10  # distance in pixels from the right and bottom edges
    try:
        try:
            font = ImageFont.truetype("Roboto-Bold.ttf", font_size)
        except (OSError, ImportError):
            # Font file not found, or Pillow was built without FreeType.
            try:
                font = ImageFont.load_default(font_size)
            except (TypeError, ImportError):
                # Pillow < 10.1 has no size argument, and without FreeType only
                # the fixed-size bitmap font exists.
                font = ImageFont.load_default()

        # Drawing an RGB image in "RGBA" mode blends each fill's alpha into the
        # pixels; in plain RGB mode the shadow's alpha of 128 would be ignored
        # and the shadow painted fully opaque.
        draw = ImageDraw.Draw(image, "RGBA" if image.mode == "RGB" else None)

        text_width = draw.textlength(WATERMARK_TEXT, font=font)
        x = image.width - text_width - margin
        y = image.height - font_size - margin
        draw.text((x + 1, y + 1), WATERMARK_TEXT, font=font, fill=(0, 0, 0, 128))
        draw.text((x, y), WATERMARK_TEXT, font=font, fill=(255, 255, 255))

        # PNG is lossless, so the round-trip keeps every pixel; optimize=True
        # only shrinks the encoding. PNG has no quality setting, so none is passed.
        buffer = io.BytesIO()
        image.save(buffer, format="PNG", optimize=True)
        buffer.seek(0)
        return Image.open(buffer)
    except Exception as e:
        print(f"Watermark error: {str(e)}")
        return image
# ===== IMAGE GENERATION =====
def generate_image(prompt):
    """Generate one watermarked image for ``prompt``.

    Args:
        prompt: Text prompt from the UI. ``None`` or whitespace-only input is
            rejected without touching the model.

    Returns:
        A ``(image, status)`` pair: the watermarked PIL image (``None`` on
        failure) and a short status message for display.
    """
    # Guard None as well as blank text: None.strip() would raise outside the try.
    if not prompt or not prompt.strip():
        return None, "⚠️ Please enter a prompt"
    try:
        model = load_model()
        result = model(
            prompt,
            num_inference_steps=30,
            guidance_scale=7.5,
        )
        image = result.images[0]
        return add_watermark(image), "✔️ Generation successful"
    except torch.cuda.OutOfMemoryError:
        return None, "⚠️ Out of memory! Try a simpler prompt"
    except Exception as e:
        # Log the full error server-side; the UI only shows a truncated message.
        print(f"Generation error: {e!r}")
        return None, f"⚠️ Error: {str(e)[:200]}"
# ===== GRADIO INTERFACE =====
# ... (keep your existing interface code exactly as is) ...