snackshell committed · Commit 087c578 · verified · 1 Parent(s): f60d82f

Update app.py

Files changed (1)
  1. app.py +119 -51
app.py CHANGED
@@ -4,107 +4,175 @@ import gradio as gr
 from PIL import Image, ImageDraw, ImageFont
 import io
 import time
+from concurrent.futures import ThreadPoolExecutor

-# Configuration
+# ===== CONFIGURATION =====
 HF_API_TOKEN = os.environ.get("HF_API_TOKEN")
-MODEL_NAME = "stabilityai/stable-diffusion-xl-base-1.0"  # High-quality model
+MODEL_NAME = "stabilityai/stable-diffusion-xl-base-1.0"
 API_URL = f"https://api-inference.huggingface.co/models/{MODEL_NAME}"
 headers = {"Authorization": f"Bearer {HF_API_TOKEN}"}
 WATERMARK_TEXT = "SelamGPT"
 MAX_RETRIES = 3
+TIMEOUT = 45  # Increased timeout for larger images
+EXECUTOR = ThreadPoolExecutor(max_workers=2)  # Handle concurrent requests

+# ===== WATERMARK FUNCTION =====
 def add_watermark(image_bytes):
-    """Add watermark to generated image"""
+    """Add professional watermark with fallback fonts"""
     try:
         image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
         draw = ImageDraw.Draw(image)

-        # Use default font (no need for external files)
-        try:
-            font = ImageFont.truetype("arialbd.ttf", 40)
-        except:
-            font = ImageFont.load_default(size=40)  # Fallback
+        # Try multiple font options
+        font = None
+        for font_path in [
+            "Roboto-Bold.ttf",      # Our Dockerfile installs this
+            "DejaVuSans-Bold.ttf",
+            "FreeSansBold.ttf",
+            None                    # Final fallback to default
+        ]:
+            try:
+                size = min(image.width // 15, 40)  # Responsive sizing
+                font = ImageFont.truetype(font_path, size) if font_path else ImageFont.load_default(size)
+                break
+            except:
+                continue

-        # Calculate text position (bottom-right corner)
+        # Calculate dynamic position
         bbox = draw.textbbox((0, 0), WATERMARK_TEXT, font=font)
-        text_width = bbox[2] - bbox[0]
-        text_height = bbox[3] - bbox[1]
-        margin = 20
-        position = (image.width - text_width - margin, image.height - text_height - margin)
+        text_w, text_h = bbox[2] - bbox[0], bbox[3] - bbox[1]
+        margin = image.width // 50
+        position = (image.width - text_w - margin, image.height - text_h - margin)
+
+        # Draw with outline effect
+        for offset in [(-1, -1), (1, 1)]:  # Shadow positions
+            draw.text(
+                (position[0] + offset[0], position[1] + offset[1]),
+                WATERMARK_TEXT,
+                font=font,
+                fill=(0, 0, 0, 180))  # Semi-transparent black

-        # Draw watermark with semi-transparent white text and black outline
         draw.text(
             position,
             WATERMARK_TEXT,
             font=font,
-            fill=(255, 255, 255, 128),
-            stroke_width=2,
-            stroke_fill=(0, 0, 0, 128)
-        )
+            fill=(255, 255, 255, 200))  # Semi-transparent white

         return image
     except Exception as e:
         print(f"Watermark error: {str(e)}")
-        return Image.open(io.BytesIO(image_bytes))  # Return original if watermark fails
+        return Image.open(io.BytesIO(image_bytes))  # Fallback to original

+# ===== IMAGE GENERATION =====
 def generate_image(prompt):
-    """Generate image with retry logic"""
+    """Generate image with robust error handling"""
     if not prompt.strip():
-        return "Error: Please enter a valid prompt"
+        return None, "⚠️ Please enter a prompt"
+
+    def api_call():
+        return requests.post(
+            API_URL,
+            headers=headers,
+            json={
+                "inputs": prompt,
+                "parameters": {
+                    "height": 768,
+                    "width": 768,
+                    "num_inference_steps": 25
+                },
+                "options": {"wait_for_model": True}
+            },
+            timeout=TIMEOUT
+        )

     for attempt in range(MAX_RETRIES):
         try:
-            response = requests.post(
-                API_URL,
-                headers=headers,
-                json={"inputs": prompt, "options": {"wait_for_model": True}},
-                timeout=30
-            )
+            future = EXECUTOR.submit(api_call)
+            response = future.result()

             if response.status_code == 200:
-                return add_watermark(response.content)
-            elif response.status_code == 503:  # Model loading
-                time.sleep(10 * (attempt + 1))  # Exponential backoff
+                return add_watermark(response.content), "✔️ Generation successful"
+            elif response.status_code == 503:
+                wait_time = (attempt + 1) * 10
+                print(f"Model loading, waiting {wait_time}s...")
+                time.sleep(wait_time)
                 continue
             else:
-                return f"API Error: {response.text}"
+                return None, f"⚠️ API Error: {response.text[:200]}"
         except requests.Timeout:
-            return "Error: Request timed out (30s)"
+            return None, "⚠️ Timeout: Model took too long to respond"
         except Exception as e:
-            return f"Unexpected error: {str(e)}"
+            return None, f"⚠️ Unexpected error: {str(e)[:200]}"

-    return "Failed after multiple attempts. Try again later."
+    return None, "⚠️ Failed after multiple attempts. Please try again later."
+
+# ===== GRADIO INTERFACE =====
+theme = gr.themes.Default(
+    primary_hue="emerald",
+    secondary_hue="amber",
+    font=[gr.themes.GoogleFont("Poppins"), "Arial", "sans-serif"]
+)

-# Gradio Interface
-with gr.Blocks(theme=gr.themes.Soft()) as demo:
+with gr.Blocks(theme=theme, title="SelamGPT Image Generator") as demo:
     gr.Markdown("""
     # 🎨 SelamGPT Image Generator
-    *Powered by Stable Diffusion XL with built-in watermark*
+    *Generate watermarked images with Stable Diffusion XL*
     """)

     with gr.Row():
-        with gr.Column():
+        with gr.Column(scale=3):
             prompt_input = gr.Textbox(
                 label="Describe your image",
-                placeholder="A futuristic city at sunset...",
-                lines=3
+                placeholder="A futuristic Ethiopian city with flying cars...",
+                lines=3,
+                max_lines=5,
+                elem_id="prompt-box"
             )
-            generate_btn = gr.Button("Generate", variant="primary")
-            examples = gr.Examples(
-                examples=["A cute robot reading a book", "Ethiopian landscape in oil painting style"],
-                inputs=prompt_input
+            with gr.Row():
+                generate_btn = gr.Button("Generate Image", variant="primary")
+                clear_btn = gr.Button("Clear")
+
+            gr.Examples(
+                examples=[
+                    ["An ancient Aksumite warrior in cyberpunk armor"],
+                    ["Traditional Ethiopian coffee ceremony in space"],
+                    ["Hyper-realistic portrait of a Habesha woman with neon tribal markings"]
+                ],
+                inputs=prompt_input,
+                label="Example Prompts"
+            )
+
+        with gr.Column(scale=2):
+            output_image = gr.Image(
+                label="Generated Image",
+                height=512,
+                elem_id="output-image"
+            )
+            status_output = gr.Textbox(
+                label="Status",
+                interactive=False,
+                elem_id="status-box"
             )
-
-        with gr.Column():
-            output_image = gr.Image(label="Generated Image", height=512)
-            error_output = gr.Textbox(label="Status", visible=False)

+    # Event handlers
     generate_btn.click(
         fn=generate_image,
         inputs=prompt_input,
-        outputs=[output_image, error_output],
-        show_progress="full"
+        outputs=[output_image, status_output],
+        queue=True,
+        show_progress="minimal"
+    )
+
+    clear_btn.click(
+        fn=lambda: [None, ""],
+        outputs=[output_image, status_output]
     )

+# ===== DEPLOYMENT CONFIG =====
 if __name__ == "__main__":
-    demo.launch(server_name="0.0.0.0", server_port=7860)
+    demo.queue(concurrency_count=2, api_open=False)
+    demo.launch(
+        server_name="0.0.0.0",
+        server_port=7860,
+        favicon_path="./favicon.ico"  # Optional
+    )
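
A quick way to sanity-check the updated generate_image outside the Gradio UI is a small smoke test like the minimal sketch below. This is not part of the commit: it assumes the updated app.py is importable as app, that HF_API_TOKEN is exported in the environment, and the prompt and output filename are purely illustrative.

# smoke_test.py — illustrative sketch only, not part of this commit
import os

# app.py reads HF_API_TOKEN at import time, so set it before importing
assert os.environ.get("HF_API_TOKEN"), "export HF_API_TOKEN before running"

from app import generate_image  # function updated in this commit

# The new signature returns a (PIL.Image or None, status message) tuple,
# matching the two Gradio outputs [output_image, status_output].
image, status = generate_image("A futuristic Ethiopian city with flying cars")
print(status)

if image is not None:
    image.save("smoke_test_output.png")  # hypothetical output path

Note that importing app also builds (but does not launch) the Gradio Blocks, so gradio, requests, and Pillow all need to be installed in the test environment.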