|
import functools
import io

import gradio as gr
import torch
from diffusers import DiffusionPipeline
from PIL import Image, ImageDraw, ImageFont
|
|
|
|
|
# Hugging Face Hub repository id of the diffusion pipeline to load.
MODEL_NAME = "HiDream-ai/HiDream-I1-Full"

# Text stamped onto every generated image by add_watermark().
WATERMARK_TEXT = "SelamGPT"

# Pick the execution device and a matching weight precision together:
# half precision on the GPU, full precision on the CPU.
# NOTE(review): upstream HiDream examples appear to use torch.bfloat16 on GPU;
# confirm float16 does not overflow (NaN / black images) for this model.
if torch.cuda.is_available():
    DEVICE, TORCH_DTYPE = "cuda", torch.float16
else:
    DEVICE, TORCH_DTYPE = "cpu", torch.float32
|
|
|
|
|
@functools.lru_cache(maxsize=1)
def load_model():
    """Load the diffusion pipeline once and return the shared instance.

    The pipeline for ``MODEL_NAME`` is loaded with ``TORCH_DTYPE`` weights and
    moved to ``DEVICE``. The result is memoised with ``functools.lru_cache``,
    so every call after the first returns the same object instead of
    reloading the weights on each generation request.

    On CUDA, xformers memory-efficient attention is tried first. If it cannot
    be enabled, attention slicing is used instead to reduce peak memory.

    Returns:
        The loaded ``DiffusionPipeline``.
    """
    pipe = DiffusionPipeline.from_pretrained(
        MODEL_NAME,
        torch_dtype=TORCH_DTYPE,
    ).to(DEVICE)

    if DEVICE == "cuda":
        try:
            pipe.enable_xformers_memory_efficient_attention()
        except Exception:
            # xformers is missing or unsupported by this pipeline. Fall back to
            # attention slicing. It is deliberately not applied on top of
            # xformers: setting a slice size installs sliced attention
            # processors, which would replace the xformers ones.
            print("Xformers not available, using default attention")
            pipe.enable_attention_slicing()

    return pipe
|
|
|
|
|
def _load_watermark_font(font_size):
    """Return Roboto Bold at *font_size*, falling back to Pillow's default font."""
    try:
        return ImageFont.truetype("Roboto-Bold.ttf", font_size)
    except (OSError, ImportError):
        # Font file not found or unreadable, or Pillow was built without FreeType.
        pass
    try:
        # Pillow >= 10.1 accepts a size for its bundled default font.
        return ImageFont.load_default(font_size)
    except (TypeError, ImportError):
        # Older Pillow (or no FreeType): fixed-size bitmap default font.
        return ImageFont.load_default()


def add_watermark(image):
    """Return a copy of *image* with WATERMARK_TEXT in the bottom-right corner.

    The text is gold with a semi-transparent black drop shadow offset by
    2 px. The font size scales with the image width (3%, at least 24 px) and
    the right margin is 2% of the width. The stamped image is re-encoded as
    an optimised PNG and reopened.

    The input image is not modified. Watermarking is best-effort: on any
    error the message is printed and the original image is returned
    unchanged.

    Args:
        image: A PIL image.

    Returns:
        The watermarked PIL image, or *image* itself if watermarking failed.
    """
    try:
        font_size = max(24, int(image.width * 0.03))
        font = _load_watermark_font(font_size)

        # Draw on a transparent RGBA overlay and composite it. ImageDraw does
        # not blend when drawing directly onto an RGB image, so the shadow's
        # alpha of 150 would otherwise render fully opaque. Working on a copy
        # also leaves the caller's image untouched.
        base = image.convert("RGBA")
        overlay = Image.new("RGBA", base.size, (0, 0, 0, 0))
        draw = ImageDraw.Draw(overlay)

        text_width = draw.textlength(WATERMARK_TEXT, font=font)
        margin = image.width * 0.02
        x = image.width - text_width - margin
        y = image.height - (font_size * 1.5)

        # Shadow first (offset 2 px), then the gold text on top of it.
        draw.text((x + 2, y + 2), WATERMARK_TEXT, font=font, fill=(0, 0, 0, 150))
        draw.text((x, y), WATERMARK_TEXT, font=font, fill=(255, 215, 0, 255))

        # Back to the caller's mode (e.g. RGB) after compositing.
        stamped = Image.alpha_composite(base, overlay).convert(image.mode)

        buffer = io.BytesIO()
        stamped.save(buffer, format="PNG", optimize=True)
        buffer.seek(0)
        return Image.open(buffer)
    except Exception as e:
        print(f"Watermark error: {str(e)}")
        return image
|
|
|
|
|
def generate_image(prompt):
    """Generate one watermarked 1024x1024 image from a text prompt.

    Runs the cached pipeline from ``load_model()`` with 30 inference steps
    and guidance scale 7.5, then stamps the result with ``add_watermark``.

    Args:
        prompt: Text description of the image.

    Returns:
        A ``(image, status_message)`` tuple, matching the
        ``[output_image, status]`` outputs wired up in the UI.

    Raises:
        gr.Error: If the prompt is missing or blank, if CUDA runs out of
            memory, or if loading/generation fails for any other reason (the
            underlying message is truncated to 200 characters).
    """
    # None can arrive through the HTTP API; treat it like an empty textbox.
    if not prompt or not prompt.strip():
        raise gr.Error("Please enter a prompt")

    try:
        model = load_model()
        result = model(
            prompt,
            num_inference_steps=30,
            guidance_scale=7.5,
            width=1024,
            height=1024,
        )
        image = result.images[0]
    except torch.cuda.OutOfMemoryError as e:
        # Return cached allocator blocks so the next request can succeed.
        torch.cuda.empty_cache()
        raise gr.Error("Out of GPU memory! Please wait a moment and try again.") from e
    except Exception as e:
        raise gr.Error(f"Generation failed: {str(e)[:200]}") from e

    # add_watermark is best-effort and never raises, so it sits outside the try.
    return add_watermark(image), "🎨 Generation complete!"
|
|
|
|
|
with gr.Blocks(theme=gr.themes.Default(
    primary_hue="emerald",
    # Theme hues must be Gradio built-in palette names; there is no "gold",
    # so "amber" (the closest built-in) is used to match the gold watermark.
    secondary_hue="amber",
    font=[gr.themes.GoogleFont("Poppins"), "Arial", "sans-serif"],
)) as demo:

    gr.Markdown("""<h1 align="center">🎨 SelamGPT HiDream Generator</h1>""")

    with gr.Row(variant="panel"):
        # Left column: prompt entry, generate button and example prompts.
        with gr.Column(scale=3):
            prompt_input = gr.Textbox(
                label="Describe your image",
                placeholder="A futuristic Ethiopian city with flying cars...",
                lines=3,
                max_lines=5,
                autofocus=True,
            )
            generate_btn = gr.Button("Generate Image", variant="primary")

            gr.Examples(
                examples=[
                    ["An ancient Aksumite warrior in cyberpunk armor, 4k detailed"],
                    ["Traditional Ethiopian coffee ceremony in zero gravity"],
                    ["Portrait of a Habesha queen with golden jewelry"],
                ],
                inputs=prompt_input,
                label="Try these prompts:",
            )

        # Right column: read-only result image and status line.
        with gr.Column(scale=2):
            output_image = gr.Image(
                label="Generated Image",
                type="pil",
                height=512,
                interactive=False,
            )
            status = gr.Textbox(
                label="Status",
                interactive=False,
                show_label=False,
            )

    # Clicking the button is exposed as the named "generate" API endpoint.
    generate_btn.click(
        fn=generate_image,
        inputs=prompt_input,
        outputs=[output_image, status],
        api_name="generate",
    )

    # Submitting the textbox triggers the same generation.
    prompt_input.submit(
        fn=generate_image,
        inputs=prompt_input,
        outputs=[output_image, status],
    )
|
|
|
if __name__ == "__main__":
    # Bind to every network interface on port 7860 (e.g. inside a container),
    # without creating a public Gradio share link.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
    }
    demo.launch(**launch_options)