import os
import torch
import gradio as gr
from diffusers import DiffusionPipeline
from PIL import Image, ImageDraw, ImageFont
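
# Assumed environment: diffusers with the DeepFloyd IF pipelines, accelerate (needed for
# enable_model_cpu_offload), Pillow >= 10.1 (ImageFont.load_default accepts a size), and a
# recent Gradio 4.x release (gr.Image supports format=...).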
# ===== FREE-TIER CONFIG =====
WATERMARK_TEXT = "SelamGPT"
MODEL_NAME = "DeepFloyd/IF-II-L-v1.0"
CACHE_DIR = "model_cache" # For free tier storage limits
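# NOTE: The DeepFloyd IF weights are gated on the Hugging Face Hub, so the Space needs an
# accepted model license and an access token (e.g. an HF_TOKEN secret) to download them.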

# ===== LIGHTWEIGHT MODEL LOAD =====
pipe = None  # Lazy load to avoid cold start timeouts

def load_model():
    global pipe
    if pipe is None:
        pipe = DiffusionPipeline.from_pretrained(
            MODEL_NAME,
            torch_dtype=torch.float16,  # 50% VRAM reduction
            variant="fp16",
            cache_dir=CACHE_DIR
        )
        pipe.enable_model_cpu_offload()  # Critical for free-tier VRAM
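
# Note: with this lazy-load pattern, the first generation request also pays the
# model download/initialization cost, so it can be slow on a cold Space.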

# ===== OPTIMIZED WATERMARK =====
def add_watermark(image):
    try:
        draw = ImageDraw.Draw(image)
        font = ImageFont.load_default(20)  # Built-in font; the size argument needs Pillow >= 10.1
        text_width = draw.textlength(WATERMARK_TEXT, font=font)
        draw.text(
            (image.width - text_width - 15, image.height - 30),
            WATERMARK_TEXT,
            font=font,
            fill=(255, 255, 255)
        )
        return image
    except Exception:
        return image  # If watermarking fails, return the unmarked image rather than erroring

# ===== FREE-TIER GENERATION =====
def generate_image(prompt):
    if not prompt.strip():
        return None, "⚠️ Please enter a prompt"
    try:
        load_model()  # Lazy load only when needed

        # Free-tier optimized settings
        result = pipe(
            prompt=prompt,
            output_type="pil",
            generator=torch.Generator().manual_seed(42),  # Fixed seed for reproducible results
            num_inference_steps=30,  # Fewer steps than the pipeline default to keep runs fast
            guidance_scale=7.0  # Balanced creativity/quality
        )
        return add_watermark(result.images[0]), "✔️ Generated (Free Tier)"
    except torch.cuda.OutOfMemoryError:
        return None, "⚠️ Out of VRAM - Try a simpler prompt"
    except Exception as e:
        return None, f"⚠️ Error: {str(e)[:100]}"

# ===== GRADIO UI =====
with gr.Blocks(title="SelamGPT Pro") as demo:
    gr.Markdown("""
# 🎨 SelamGPT (DeepFloyd IF-I-L)
*Optimized for Free Tier - 64px Base Resolution*
""")
    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(
                label="Describe your image",
                placeholder="A traditional Ethiopian market...",
                lines=3
            )
            generate_btn = gr.Button("Generate", variant="primary")
            gr.Examples(
                examples=[
                    ["Habesha cultural dress with intricate patterns, studio lighting"],
                    ["Lalibela rock-hewn churches at golden hour"],
                    ["Addis Ababa futuristic skyline, cyberpunk style"]
                ],
                inputs=prompt_input
            )
        with gr.Column():
            output_image = gr.Image(
                label="Generated Image",
                type="pil",
                format="webp",  # Lightweight format
                height=400
            )
            status_output = gr.Textbox(
                label="Status",
                interactive=False
            )

    generate_btn.click(
        fn=generate_image,
        inputs=prompt_input,
        outputs=[output_image, status_output]
    )

if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860
        # launch(enable_queue=...) was removed in Gradio 4; requests are queued by default,
        # which is adequate for a single free-tier worker.
    )