File size: 3,527 Bytes
882e052
dc45b2e
c8f1f54
dc45b2e
4fe456a
c8f1f54
dc45b2e
882e052
dc45b2e
 
c8f1f54
dc45b2e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
882e052
5d4b17c
dc45b2e
9aeab3c
dc45b2e
 
 
 
 
 
 
 
 
c8f1f54
dc45b2e
882e052
 
087c578
 
dc45b2e
 
 
 
 
 
 
 
 
 
087c578
dc45b2e
 
 
 
 
 
faa166e
dc45b2e
 
882e052
dc45b2e
 
882e052
 
 
dc45b2e
882e052
faa166e
dc45b2e
4cda5f1
882e052
dc45b2e
087c578
 
 
dc45b2e
 
 
087c578
5d4b17c
087c578
 
dc45b2e
087c578
dc45b2e
5d4b17c
dc45b2e
 
087c578
 
 
5d4b17c
882e052
 
 
 
 
087c578
882e052
c8f1f54
 
dc45b2e
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import os
import torch
import gradio as gr
from diffusers import DiffusionPipeline
from PIL import Image, ImageDraw, ImageFont

# ===== FREE-TIER CONFIG =====
WATERMARK_TEXT = "SelamGPT"
# Stage-I (text-to-image, 64px) DeepFloyd IF checkpoint. The stage-II
# checkpoint "DeepFloyd/IF-II-L-v1.0" is a super-resolution pipeline that
# requires a low-resolution input image, so generating from a prompt alone
# (as generate_image does) cannot work with it.
MODEL_NAME = "DeepFloyd/IF-I-L-v1.0"
CACHE_DIR = "model_cache"  # For free tier storage limits

# ===== LIGHTWEIGHT MODEL LOAD =====
pipe = None  # Lazy load to avoid cold start timeouts

def load_model():
    """Load the diffusion pipeline into the module-level ``pipe`` on first use.

    Later calls are no-ops. The pipeline is stored in ``pipe`` only after it
    is fully configured. If downloading or offload setup raises, ``pipe``
    stays ``None`` and the next call retries, rather than reusing a
    half-initialised pipeline.

    Returns:
        The loaded pipeline (the same object as the global ``pipe``).
    """
    global pipe
    if pipe is None:
        loaded = DiffusionPipeline.from_pretrained(
            MODEL_NAME,
            torch_dtype=torch.float16,  # Half precision: ~50% VRAM vs fp32
            variant="fp16",  # Fetch the fp16 weight files directly
            cache_dir=CACHE_DIR,
        )
        # Keep sub-models off the GPU until they run; critical for free-tier VRAM.
        loaded.enable_model_cpu_offload()
        # Publish only once configuration succeeded.
        pipe = loaded
    return pipe

# ===== OPTIMIZED WATERMARK =====
def add_watermark(image):
    """Draw ``WATERMARK_TEXT`` in white near the bottom-right corner of *image*.

    The image is modified in place and also returned. Watermarking is
    best-effort: if drawing fails, the failure is logged and the image is
    returned as-is, so a cosmetic problem never discards a generated image.
    """
    try:
        draw = ImageDraw.Draw(image)
        try:
            # Passing a size to load_default() needs Pillow >= 10.1.
            font = ImageFont.load_default(20)
        except TypeError:
            # Older Pillow: load_default() takes no size; use the bitmap font.
            font = ImageFont.load_default()
        text_width = draw.textlength(WATERMARK_TEXT, font=font)
        # Clamp to the canvas. On small outputs (e.g. 64px) the unclamped
        # offsets go negative and push the text off the image.
        x = max(0, image.width - text_width - 15)
        y = max(0, image.height - 30)
        draw.text((x, y), WATERMARK_TEXT, font=font, fill=(255, 255, 255))
        return image
    except Exception:
        import logging

        logging.getLogger(__name__).warning(
            "Watermark skipped; returning image unmodified", exc_info=True
        )
        return image

# ===== FREE-TIER GENERATION =====
def generate_image(prompt):
    """Gradio click handler: generate one watermarked image for *prompt*.

    Returns:
        A ``(image, status)`` tuple. On success this is a PIL image and a
        success message. On failure the image is ``None`` and the status
        holds a user-facing warning. Errors are reported through the status
        text instead of being raised, so the UI keeps working.
    """
    # Gradio may pass None (e.g. via the API), so check before calling .strip().
    if not prompt or not prompt.strip():
        return None, "⚠️ Please enter a prompt"

    try:
        load_model()  # Lazy load only when needed

        # Free-tier optimized settings
        result = pipe(
            prompt=prompt,
            output_type="pil",
            # Fixed seed: the same prompt reproduces the same image.
            generator=torch.Generator().manual_seed(42),
            num_inference_steps=30,  # Fewer steps than the default to cut latency
            guidance_scale=7.0,  # Balanced creativity/quality
        )

        return add_watermark(result.images[0]), "✔️ Generated (Free Tier)"
    except torch.cuda.OutOfMemoryError:
        # Return cached allocator blocks so the next request starts clean.
        torch.cuda.empty_cache()
        return None, "⚠️ Out of VRAM - Try simpler prompt"
    except Exception as e:
        return None, f"⚠️ Error: {str(e)[:100]}"

# ===== GRADIO UI =====
with gr.Blocks(title="SelamGPT Pro") as demo:
    # The header is derived from MODEL_NAME so the UI always names the
    # checkpoint that load_model() actually loads.
    gr.Markdown(f"""
    # 🎨 SelamGPT ({MODEL_NAME})
    *Optimized for Free Tier - 64px Base Resolution*
    """)

    with gr.Row():
        # Left column: prompt entry, trigger button, and canned examples.
        with gr.Column():
            prompt_input = gr.Textbox(
                label="Describe your image",
                placeholder="A traditional Ethiopian market...",
                lines=3
            )
            generate_btn = gr.Button("Generate", variant="primary")

            gr.Examples(
                examples=[
                    ["Habesha cultural dress with intricate patterns, studio lighting"],
                    ["Lalibela rock-hewn churches at golden hour"],
                    ["Addis Ababa futuristic skyline, cyberpunk style"]
                ],
                inputs=prompt_input
            )

        # Right column: generated image and a read-only status line.
        with gr.Column():
            output_image = gr.Image(
                label="Generated Image",
                type="pil",
                format="webp",  # Lightweight format
                height=400
            )
            status_output = gr.Textbox(
                label="Status",
                interactive=False
            )

    # generate_image returns (image, status), matching these two outputs.
    generate_btn.click(
        fn=generate_image,
        inputs=prompt_input,
        outputs=[output_image, status_output]
    )

if __name__ == "__main__":
    # `enable_queue` was removed from launch() in Gradio 4, and passing it
    # raises TypeError. Queueing is configured via demo.queue() instead.
    # Gradio 3 also runs without a queue unless demo.queue() is called, so
    # dropping the flag keeps the old behaviour there.
    demo.launch(
        server_name="0.0.0.0",  # Listen on all interfaces (container hosting)
        server_port=7860,
    )