# snackshell's picture
# Update app.py
# d76eec6 verified
# raw
# history blame
# 6.07 kB
import os
import requests
import gradio as gr
from PIL import Image, ImageDraw, ImageFont
import io
import time
from concurrent.futures import ThreadPoolExecutor
# ===== CONFIGURATION =====
# Hugging Face Inference API access. If HF_API_TOKEN is unset the header
# value becomes the literal string "Bearer None".
HF_API_TOKEN = os.getenv("HF_API_TOKEN")
MODEL_NAME = "stabilityai/stable-diffusion-xl-base-1.0"
API_URL = f"https://api-inference.huggingface.co/models/{MODEL_NAME}"
headers = dict(Authorization=f"Bearer {HF_API_TOKEN}")

# Watermark text stamped onto every generated image.
WATERMARK_TEXT = "SelamGPT"

# Request policy: attempts for HTTP 503 responses and per-request timeout (s).
MAX_RETRIES, TIMEOUT = 3, 60

# Shared worker pool that API calls are submitted to (at most 2 at once).
EXECUTOR = ThreadPoolExecutor(max_workers=2)

# Preferred watermark font (Ethiopic script); a fallback is used if missing.
GEEZ_FONT_PATH = "fonts/NotoSansEthiopic-Bold.ttf"
# ===== AUTHENTIC ETHIOPIAN THEME =====
# Base theme: hues follow the Ethiopian flag (green / yellow / red) and the
# Noto Sans Ethiopic web font is listed first, with Arial and generic
# sans-serif as fallbacks.
theme = gr.themes.Default(
    font=[gr.themes.GoogleFont("Noto Sans Ethiopic"), "Arial", "sans-serif"],
    primary_hue="green",
    secondary_hue="yellow",
    neutral_hue="red",
)
# Individual CSS-variable overrides applied on top of the base theme;
# .set() hands back the theme object, which is what the UI consumes.
theme = theme.set(
    slider_color="#DA1212",  # Red
    button_secondary_background_fill="#FCDD09",  # Yellow
    button_primary_text_color="#ffffff",
    # Ethiopian flag gradient: green -> yellow -> red.
    button_primary_background_fill="linear-gradient(90deg, #078C03 0%, #FCDD09 50%, #DA1212 100%)",
)
# ===== WATERMARK FUNCTION =====
def _load_watermark_font(size=30):
    """Return the Ge'ez TrueType font at *size*, falling back to Pillow's default.

    ``ImageFont.truetype`` raises ``OSError`` when the font file is missing or
    unreadable. ``ImageFont.load_default`` accepts a ``size`` argument only on
    Pillow >= 10.1; older versions raise ``TypeError``, in which case the
    fixed-size built-in font is used instead of abandoning the watermark.
    """
    try:
        return ImageFont.truetype(GEEZ_FONT_PATH, size)
    except OSError:
        pass
    try:
        return ImageFont.load_default(size=size)
    except TypeError:
        return ImageFont.load_default()


def add_watermark(image_bytes):
    """Decode *image_bytes* and stamp ``WATERMARK_TEXT`` near the bottom-right corner.

    Args:
        image_bytes: Encoded image data, e.g. the response body returned by
            the inference API.

    Returns:
        An RGB ``PIL.Image.Image``. The watermark is best-effort: if drawing it
        fails, the error is printed and the decoded, unmarked RGB image is
        returned instead.

    Raises:
        OSError (including ``PIL.UnidentifiedImageError``): if *image_bytes*
        cannot be decoded as an image at all.
    """
    # Decode once; a decode failure propagates to the caller.
    image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    try:
        # Draw on a copy so a mid-drawing failure never leaves a half-drawn mark.
        stamped = image.copy()
        draw = ImageDraw.Draw(stamped)
        font = _load_watermark_font(30)

        # 20 px from the right edge, 40 px from the bottom; clamp so very small
        # images don't push the text off the left/top edge.
        text_width = draw.textlength(WATERMARK_TEXT, font=font)
        x = max(stamped.width - text_width - 20, 0)
        y = max(stamped.height - 40, 0)

        # Ethiopian flag colours: green shadow offset by 2 px, yellow text on top.
        draw.text((x + 2, y + 2), WATERMARK_TEXT, font=font, fill="#078C03")
        draw.text((x, y), WATERMARK_TEXT, font=font, fill="#FCDD09")
        return stamped
    except Exception as e:
        print(f"Watermark error: {str(e)}")
        return image
# ===== IMAGE GENERATION =====
def generate_image(prompt):
    """Generate an Ethiopian-themed image for *prompt* via the HF Inference API.

    The stripped user prompt is suffixed with fixed cultural style keywords and
    POSTed to ``API_URL`` (run on the shared ``EXECUTOR``) with fixed SDXL
    generation parameters. HTTP 503 responses are retried up to
    ``MAX_RETRIES`` times with a linearly growing back-off (15 s, 30 s, ...);
    no back-off is done after the final attempt.

    Args:
        prompt: Text from the Gradio textbox. ``None``, empty, or
            whitespace-only input is rejected without calling the API.

    Returns:
        ``(image, status)``: *image* is the watermarked PIL image on success
        and ``None`` otherwise; *status* is a user-facing message.
    """
    # Reject None as well as blank text (None would crash on .strip()).
    if not prompt or not prompt.strip():
        return None, "⚠️ Please enter a prompt"

    # Ethiopian cultural enhancement appended to every prompt.
    cultural_enhancements = [
        "Ethiopian style", "Habesha culture", "vibrant colors",
        "traditional elements", "East African aesthetic",
    ]
    enhanced_prompt = f"{prompt.strip()}, {', '.join(cultural_enhancements)}"

    payload = {
        "inputs": enhanced_prompt,
        "parameters": {
            "height": 1024,
            "width": 1024,
            "num_inference_steps": 30,
            "guidance_scale": 8.0,
            "negative_prompt": "western, non-African, cartoonish, low quality",
        },
        # Ask the service to wait for a cold model rather than fail fast
        # (per the HF Inference API docs).
        "options": {"wait_for_model": True},
    }

    def api_call():
        return requests.post(API_URL, headers=headers, json=payload, timeout=TIMEOUT)

    for attempt in range(1, MAX_RETRIES + 1):
        try:
            response = EXECUTOR.submit(api_call).result()
            if response.status_code == 200:
                return add_watermark(response.content), "✔️ Generation successful"
            if response.status_code != 503:
                return None, f"⚠️ API Error: {response.text[:200]}"
        except requests.Timeout:
            return None, f"⚠️ Timeout: Model took >{TIMEOUT}s to respond"
        except Exception as e:
            return None, f"⚠️ Unexpected error: {str(e)[:200]}"

        # HTTP 503 (model loading): back off before retrying, but not after the
        # last attempt — that sleep would only delay the failure message.
        if attempt < MAX_RETRIES:
            wait_time = attempt * 15
            print(f"Model loading, waiting {wait_time}s...")
            time.sleep(wait_time)

    return None, "⚠️ Failed after multiple attempts. Please try later."
# ===== GRADIO INTERFACE =====
# Title letters coloured in the repeating flag sequence green / yellow / red.
_TITLE_COLORS = ("#078C03", "#FCDD09", "#DA1212")
_TITLE_HTML = "".join(
    f'<span style="color:{_TITLE_COLORS[i % len(_TITLE_COLORS)]}">{letter}</span>'
    for i, letter in enumerate("SELAMGPT")
)
_FLAG_IMAGE_PATH = "ethiopia_flag.png"  # Add your flag image

with gr.Blocks(theme=theme, title="SelamGPT - Ethiopian Image Generator") as demo:
    with gr.Row():
        # The flag asset is user-supplied; skip the logo when it is absent
        # rather than handing Gradio a path that does not exist.
        if os.path.exists(_FLAG_IMAGE_PATH):
            gr.Image(_FLAG_IMAGE_PATH, height=80, show_label=False, show_download_button=False)
        # A single <h1> with inline <span>s keeps "SELAMGPT" on one line; one
        # block-level <h1> per letter stacked the letters vertically.
        gr.Markdown(f"""
<center><h1>{_TITLE_HTML}</h1></center>
<center><i>Ethiopian AI Image Generator</i></center>
""")

    with gr.Row():
        with gr.Column(scale=3):
            prompt_input = gr.Textbox(
                label="Describe your image (English or Amharic)",
                placeholder="ሰላም! Describe an Ethiopian scene...",
                lines=3,
            )
            with gr.Row():
                generate_btn = gr.Button("Generate", variant="primary")
                clear_btn = gr.Button("Clear")
            # One value per row: there is a single input component, so any
            # extra value in a row would never be used.
            gr.Examples(
                examples=[
                    # "Addis Ababa at night with Ethiopian lights"
                    ["አዲስ አበባ በሌሊት ፣ የኢትዮጵያ ባህላዊ መብራቶች"],
                    ["Habesha cultural dress with intricate embroidery"],
                    ["Simien Mountains landscape with gelada monkeys"],
                ],
                inputs=prompt_input,
            )

        with gr.Column(scale=2):
            output_image = gr.Image(
                label="Generated Image",
                type="pil",
                format="png",
                height=512,
            )
            status_output = gr.Textbox(
                label="Status",
                interactive=False,
            )

    generate_btn.click(
        fn=generate_image,
        inputs=prompt_input,
        outputs=[output_image, status_output],
        queue=True,
    )
    # Reset the image and status panes (the prompt text is kept).
    clear_btn.click(
        fn=lambda: [None, ""],
        outputs=[output_image, status_output],
    )
if __name__ == "__main__":
    # Listen on every interface at port 7860; at most 2 requests may wait in
    # the queue.
    launch_options = {"server_name": "0.0.0.0", "server_port": 7860}
    demo.queue(max_size=2)
    demo.launch(**launch_options)