|
import os |
|
import torch |
|
import gradio as gr |
|
from diffusers import DiffusionPipeline |
|
from PIL import Image, ImageDraw, ImageFont |
|
|
|
|
|
# Text stamped into the bottom-right corner of every generated image.
WATERMARK_TEXT = "SelamGPT"

# NOTE(review): IF-II-L is DeepFloyd's *stage II* (upscaler) model; stage-II
# pipelines normally require an `image` input alongside the prompt — confirm
# that the text-only call in generate_image() actually works with this model.
MODEL_NAME = "DeepFloyd/IF-II-L-v1.0"

# Local directory where Hugging Face model downloads are cached between runs.
CACHE_DIR = "model_cache"

# Lazily-initialized DiffusionPipeline singleton; populated by load_model()
# on first generation so the app starts without downloading weights.
pipe = None
|
|
|
def load_model():
    """Lazily build the shared DeepFloyd pipeline (first call only).

    Loads the fp16 variant of ``MODEL_NAME`` into the module-level ``pipe``
    singleton and enables CPU offload so it fits in limited VRAM. Subsequent
    calls are no-ops.
    """
    global pipe
    if pipe is not None:
        return  # already loaded — nothing to do

    pipe = DiffusionPipeline.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.float16,
        variant="fp16",
        cache_dir=CACHE_DIR,
    )
    # Stream weights between CPU and GPU on demand to reduce peak VRAM.
    pipe.enable_model_cpu_offload()
|
|
|
|
|
def add_watermark(image):
    """Stamp ``WATERMARK_TEXT`` into the bottom-right corner of *image*.

    Best-effort by design: on any failure the original image is returned
    unchanged so a watermarking problem never blocks image delivery.

    Args:
        image: PIL.Image to annotate (drawn on in place).

    Returns:
        The same image, watermarked when possible.
    """
    try:
        draw = ImageDraw.Draw(image)
        try:
            # The size argument to load_default() requires Pillow >= 10.1.
            font = ImageFont.load_default(20)
        except TypeError:
            # Older Pillow: fall back to the fixed-size default font instead
            # of letting the outer handler silently skip the watermark.
            font = ImageFont.load_default()
        text_width = draw.textlength(WATERMARK_TEXT, font=font)
        draw.text(
            (image.width - text_width - 15, image.height - 30),
            WATERMARK_TEXT,
            font=font,
            fill=(255, 255, 255)
        )
        return image
    except Exception:
        # Deliberate best-effort: never fail generation over a watermark.
        return image
|
|
|
|
|
def generate_image(prompt):
    """Run the diffusion pipeline for *prompt*.

    Returns:
        A ``(image, status)`` tuple. ``image`` is a watermarked PIL image on
        success or ``None`` on failure; ``status`` is always a displayable
        message string for the UI.
    """
    # Guard clause: reject blank / whitespace-only prompts up front.
    if not prompt.strip():
        return None, "⚠️ Please enter a prompt"

    try:
        load_model()

        # Fixed seed => reproducible output for identical prompts.
        seed_gen = torch.Generator().manual_seed(42)
        output = pipe(
            prompt=prompt,
            output_type="pil",
            generator=seed_gen,
            num_inference_steps=30,
            guidance_scale=7.0,
        )

        watermarked = add_watermark(output.images[0])
        return watermarked, "✔️ Generated (Free Tier)"
    except torch.cuda.OutOfMemoryError:
        return None, "⚠️ Out of VRAM - Try simpler prompt"
    except Exception as e:
        # Surface a truncated error message rather than crashing the UI.
        return None, f"⚠️ Error: {str(e)[:100]}"
|
|
|
|
|
# Gradio UI layout: prompt input + sample prompts in the left column,
# generated image + status line in the right column. `demo` is launched
# from the __main__ guard at the bottom of the file.
with gr.Blocks(title="SelamGPT Pro") as demo:
    gr.Markdown("""
    # 🎨 SelamGPT (DeepFloyd IF-II-L)
    *Optimized for Free Tier - 64px Base Resolution*
    """)

    with gr.Row():
        with gr.Column():
            # Free-text description of the desired image.
            prompt_input = gr.Textbox(
                label="Describe your image",
                placeholder="A traditional Ethiopian market...",
                lines=3
            )
            generate_btn = gr.Button("Generate", variant="primary")

            # One-click sample prompts that populate prompt_input.
            gr.Examples(
                examples=[
                    ["Habesha cultural dress with intricate patterns, studio lighting"],
                    ["Lalibela rock-hewn churches at golden hour"],
                    ["Addis Ababa futuristic skyline, cyberpunk style"]
                ],
                inputs=prompt_input
            )

        with gr.Column():
            # Result pane; webp keeps response payloads small.
            output_image = gr.Image(
                label="Generated Image",
                type="pil",
                format="webp",
                height=400
            )
            # Read-only status / error line fed by generate_image().
            status_output = gr.Textbox(
                label="Status",
                interactive=False
            )

    # Wire the button to the generation function; it returns
    # (image, status) matching the two outputs below.
    generate_btn.click(
        fn=generate_image,
        inputs=prompt_input,
        outputs=[output_image, status_output]
    )
|
|
|
if __name__ == "__main__":
    # `enable_queue` was removed from Blocks.launch() in Gradio 4.x and
    # passing it raises a TypeError there, so it is omitted entirely.
    demo.launch(
        server_name="0.0.0.0",  # listen on all interfaces (container/Spaces)
        server_port=7860,
    )