|
import os |
|
import requests |
|
import gradio as gr |
|
from PIL import Image, ImageDraw, ImageFont |
|
import io |
|
import time |
|
from concurrent.futures import ThreadPoolExecutor |
|
|
|
|
|
# --- Hugging Face Inference API configuration --------------------------------
# The API token is read from the environment; no default is shipped.
HF_API_TOKEN = os.environ.get("HF_API_TOKEN")

MODEL_NAME = "stabilityai/stable-diffusion-xl-base-1.0"

API_URL = f"https://api-inference.huggingface.co/models/{MODEL_NAME}"

# Only attach an Authorization header when a token is configured; otherwise
# f"Bearer {HF_API_TOKEN}" would send the literal credential "Bearer None".
if HF_API_TOKEN:
    headers = {"Authorization": f"Bearer {HF_API_TOKEN}"}
else:
    headers = {}
    print("Warning: HF_API_TOKEN is not set; API requests will be sent without authentication.")

# Text stamped onto every generated image.
WATERMARK_TEXT = "SelamGPT"

# Maximum number of API attempts per generation request.
MAX_RETRIES = 3

# Per-request timeout in seconds for the HTTP call to the inference API.
TIMEOUT = 60

# Shared worker pool for API calls; bounds how many requests run at once.
EXECUTOR = ThreadPoolExecutor(max_workers=2)

# TrueType font used for the watermark text.
GEEZ_FONT_PATH = "fonts/NotoSansEthiopic-Bold.ttf"
|
|
|
|
|
# Ethiopian flag palette (green / yellow / red). The Ethiopic-script Google
# font is listed first, with Arial and a generic sans-serif as fallbacks.
theme = (
    gr.themes.Default(
        font=[gr.themes.GoogleFont("Noto Sans Ethiopic"), "Arial", "sans-serif"],
        neutral_hue="red",
        secondary_hue="yellow",
        primary_hue="green",
    )
    .set(
        # Primary buttons sweep through the three flag colours left to right.
        button_primary_background_fill=(
            "linear-gradient(90deg, #078C03 0%, #FCDD09 50%, #DA1212 100%)"
        ),
        button_primary_text_color="#ffffff",
        button_secondary_background_fill="#FCDD09",
        slider_color="#DA1212",
    )
)
|
|
|
|
|
def _load_watermark_font(size=30):
    """Return the watermark font at *size*, falling back to Pillow's built-in font.

    ``ImageFont.truetype`` raises ``OSError`` when the font file is missing or
    unreadable (and ``ImportError`` when Pillow lacks FreeType support).
    ``ImageFont.load_default`` only accepts a size on Pillow >= 10.1; older
    versions raise ``TypeError``, so the fixed-size default font is used then.
    """
    try:
        return ImageFont.truetype(GEEZ_FONT_PATH, size)
    except (OSError, ImportError):
        pass
    try:
        return ImageFont.load_default(size)
    except (TypeError, ImportError):
        return ImageFont.load_default()


def add_watermark(image_bytes):
    """Stamp ``WATERMARK_TEXT`` in the bottom-right corner of an encoded image.

    The text is drawn twice: a green copy offset by 2 px as a drop shadow,
    then a yellow copy on top.

    Args:
        image_bytes: Encoded image data (e.g. the body of the API response).

    Returns:
        The watermarked image, converted to RGB. If watermarking fails for any
        reason, the error is printed and the original, unconverted image is
        returned instead (decoded again from *image_bytes*).

    Raises:
        Whatever ``Image.open`` raises if *image_bytes* is not a decodable
        image, because the fallback path decodes the bytes a second time.
    """
    try:
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        draw = ImageDraw.Draw(image)
        font = _load_watermark_font(30)

        text_width = draw.textlength(WATERMARK_TEXT, font=font)
        # 20 px from the right edge, 40 px from the bottom; clamp to 0 so a
        # small image cannot push the text off the top/left edge.
        x = max(image.width - text_width - 20, 0)
        y = max(image.height - 40, 0)

        draw.text((x + 2, y + 2), WATERMARK_TEXT, font=font, fill="#078C03")  # shadow
        draw.text((x, y), WATERMARK_TEXT, font=font, fill="#FCDD09")
        return image
    except Exception as e:
        # Best effort: a failed watermark should not lose the generated image.
        print(f"Watermark error: {str(e)}")
        return Image.open(io.BytesIO(image_bytes))
|
|
|
|
|
def generate_image(prompt):
    """Generate an Ethiopian-styled image for *prompt* via the HF Inference API.

    The user's prompt is suffixed with fixed cultural style keywords and sent
    to ``API_URL``. HTTP 503 responses are retried with a linearly increasing
    back-off (15 s, 30 s, ...), up to ``MAX_RETRIES`` attempts in total. No
    sleep happens after the last attempt.

    Args:
        prompt: Text from the prompt box. ``None``, empty or whitespace-only
            input is rejected without calling the API.

    Returns:
        ``(image, status)`` for the two Gradio outputs. ``image`` is the
        watermarked PIL image on success and ``None`` on any failure.
        ``status`` is a human-readable message.
    """
    if not prompt or not prompt.strip():
        return None, "⚠️ Please enter a prompt"

    # Fixed style keywords appended to every prompt to steer the model.
    cultural_enhancements = [
        "Ethiopian style", "Habesha culture", "vibrant colors",
        "traditional elements", "East African aesthetic",
    ]
    enhanced_prompt = f"{prompt}, {', '.join(cultural_enhancements)}"

    # The request body is identical for every attempt, so build it once.
    payload = {
        "inputs": enhanced_prompt,
        "parameters": {
            "height": 1024,
            "width": 1024,
            "num_inference_steps": 30,
            "guidance_scale": 8.0,
            "negative_prompt": "western, non-African, cartoonish, low quality",
        },
        "options": {"wait_for_model": True},
    }

    def api_call():
        return requests.post(API_URL, headers=headers, json=payload, timeout=TIMEOUT)

    for attempt in range(1, MAX_RETRIES + 1):
        try:
            # Routed through the shared executor so all users share its
            # bounded worker pool; this call still blocks for the result.
            response = EXECUTOR.submit(api_call).result()

            if response.status_code == 200:
                return add_watermark(response.content), "✔️ Generation successful"
            if response.status_code != 503:
                return None, f"⚠️ API Error: {response.text[:200]}"

            # 503: back off before retrying, but never sleep after the final
            # attempt -- nothing follows it except the failure message.
            if attempt < MAX_RETRIES:
                wait_time = attempt * 15
                print(f"Model loading, waiting {wait_time}s...")
                time.sleep(wait_time)
        except requests.Timeout:
            return None, f"⚠️ Timeout: Model took >{TIMEOUT}s to respond"
        except Exception as e:
            return None, f"⚠️ Unexpected error: {str(e)[:200]}"

    return None, "⚠️ Failed after multiple attempts. Please try later."
|
|
|
|
|
with gr.Blocks(theme=theme, title="SelamGPT - Ethiopian Image Generator") as demo:
    # Header row: flag banner loaded from a path relative to the working directory.
    with gr.Row():
        gr.Image("ethiopia_flag.png", height=80, show_label=False, show_download_button=False)

    # Title letters cycle through the flag colours (green, yellow, red). They are
    # inline <span>s inside a single <h1>; separate <h1> elements would stack each
    # letter on its own line because <h1> is a block-level element.
    gr.Markdown("""
<center>
<h1><span style="color:#078C03">S</span><span style="color:#FCDD09">E</span><span style="color:#DA1212">L</span><span style="color:#078C03">A</span><span style="color:#FCDD09">M</span><span style="color:#DA1212">G</span><span style="color:#078C03">P</span><span style="color:#FCDD09">T</span></h1>
</center>
<center><i>Ethiopian AI Image Generator</i></center>
""")

    with gr.Row():
        # Left column: prompt entry, action buttons and example prompts.
        with gr.Column(scale=3):
            prompt_input = gr.Textbox(
                label="Describe your image (English or Amharic)",
                placeholder="ሰላም! Describe an Ethiopian scene...",
                lines=3,
            )
            with gr.Row():
                generate_btn = gr.Button("Generate", variant="primary")
                clear_btn = gr.Button("Clear")

            # `inputs` is a single Textbox, so every example row holds exactly
            # one value; the Amharic prompt and its English translation are
            # therefore separate examples.
            gr.Examples(
                examples=[
                    ["አዲስ አበባ በሌሊት ፣ የኢትዮጵያ ባህላዊ መብራቶች"],
                    ["Addis Ababa at night with Ethiopian lights"],
                    ["Habesha cultural dress with intricate embroidery"],
                    ["Simien Mountains landscape with gelada monkeys"],
                ],
                inputs=prompt_input,
            )

        # Right column: generated image and status message.
        with gr.Column(scale=2):
            output_image = gr.Image(
                label="Generated Image",
                type="pil",
                format="png",
                height=512,
            )
            status_output = gr.Textbox(
                label="Status",
                interactive=False,
            )

    generate_btn.click(
        fn=generate_image,
        inputs=prompt_input,
        outputs=[output_image, status_output],
        queue=True,
    )

    # Resets the image and status outputs; the prompt text is left as-is.
    clear_btn.click(
        fn=lambda: [None, ""],
        outputs=[output_image, status_output],
    )
|
|
|
if __name__ == "__main__":
    # Enable the request queue (capped at 2 waiting events), then serve the
    # app on all network interfaces at port 7860.
    demo.queue(max_size=2).launch(server_name="0.0.0.0", server_port=7860)