snackshell commited on
Commit
5d4b17c
·
verified ·
1 Parent(s): e86d5f6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -48
app.py CHANGED
@@ -8,28 +8,21 @@ from concurrent.futures import ThreadPoolExecutor
8
 
9
  # ===== CONFIGURATION =====
10
  HF_API_TOKEN = os.environ.get("HF_API_TOKEN")
11
- MODEL_NAME = "stabilityai/stable-diffusion-xl-base-1.0" # Using SDXL
12
  API_URL = f"https://api-inference.huggingface.co/models/{MODEL_NAME}"
13
  headers = {"Authorization": f"Bearer {HF_API_TOKEN}"}
14
  WATERMARK_TEXT = "SelamGPT"
15
  MAX_RETRIES = 3
16
- TIMEOUT = 60 # Increased for SDXL's longer processing
17
  EXECUTOR = ThreadPoolExecutor(max_workers=2)
18
 
19
  # ===== WATERMARK FUNCTION =====
20
  def add_watermark(image_bytes):
21
- """Convert to PNG with medium quality before watermarking"""
22
  try:
23
  image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
 
24
 
25
- # Save as medium-quality PNG to buffer
26
- png_buffer = io.BytesIO()
27
- image.save(png_buffer, format="PNG", optimize=True, quality=85) # Medium quality
28
- png_buffer.seek(0)
29
-
30
- # Add watermark to the PNG
31
- watermarked_image = Image.open(png_buffer)
32
- draw = ImageDraw.Draw(watermarked_image)
33
  font_size = 24
34
  try:
35
  font = ImageFont.truetype("Roboto-Bold.ttf", font_size)
@@ -37,25 +30,23 @@ def add_watermark(image_bytes):
37
  font = ImageFont.load_default(font_size)
38
 
39
  text_width = draw.textlength(WATERMARK_TEXT, font=font)
40
- x = watermarked_image.width - text_width - 10
41
- y = watermarked_image.height - 34
42
 
43
  draw.text((x+1, y+1), WATERMARK_TEXT, font=font, fill=(0, 0, 0, 128))
44
  draw.text((x, y), WATERMARK_TEXT, font=font, fill=(255, 255, 255))
45
 
46
- # Return as PNG bytes
47
- final_buffer = io.BytesIO()
48
- watermarked_image.save(final_buffer, format="PNG", optimize=True, quality=85)
49
- final_buffer.seek(0)
50
- return Image.open(final_buffer)
51
-
52
  except Exception as e:
53
  print(f"Watermark error: {str(e)}")
54
  return Image.open(io.BytesIO(image_bytes))
55
 
56
- # ===== IMAGE GENERATION (SDXL-OPTIMIZED) =====
57
  def generate_image(prompt):
58
- """Generate image with SDXL-specific parameters"""
59
  if not prompt.strip():
60
  return None, "⚠️ Please enter a prompt"
61
 
@@ -66,10 +57,9 @@ def generate_image(prompt):
66
  json={
67
  "inputs": prompt,
68
  "parameters": {
69
- "height": 1024, # SDXL's native resolution
70
  "width": 1024,
71
- "num_inference_steps": 30, # Better quality than 25
72
- "guidance_scale": 7.5 # SDXL's optimal value
73
  },
74
  "options": {"wait_for_model": True}
75
  },
@@ -84,7 +74,7 @@ def generate_image(prompt):
84
  if response.status_code == 200:
85
  return add_watermark(response.content), "✔️ Generation successful"
86
  elif response.status_code == 503:
87
- wait_time = (attempt + 1) * 15 # Longer wait for SDXL
88
  print(f"Model loading, waiting {wait_time}s...")
89
  time.sleep(wait_time)
90
  continue
@@ -97,20 +87,18 @@ def generate_image(prompt):
97
 
98
  return None, "⚠️ Failed after multiple attempts. Please try later."
99
 
100
- # ===== GRADIO INTERFACE =====
101
- with gr.Blocks() as demo:
102
-
103
- output_image = gr.Image(
104
- label="Generated Image",
105
- type="pil", # Force PIL/PNG output
106
- format="png", # Explicit PNG format
107
- height=512
108
- )
109
 
 
110
  with gr.Blocks(theme=theme, title="SelamGPT Image Generator") as demo:
111
  gr.Markdown("""
112
  # 🎨 SelamGPT Image Generator
113
- *Now powered by Stable Diffusion XL (1024x1024 resolution)*
114
  """)
115
 
116
  with gr.Row():
@@ -119,8 +107,7 @@ with gr.Blocks(theme=theme, title="SelamGPT Image Generator") as demo:
119
  label="Describe your image",
120
  placeholder="A futuristic Ethiopian city with flying cars...",
121
  lines=3,
122
- max_lines=5,
123
- elem_id="prompt-box"
124
  )
125
  with gr.Row():
126
  generate_btn = gr.Button("Generate Image", variant="primary")
@@ -129,31 +116,29 @@ with gr.Blocks(theme=theme, title="SelamGPT Image Generator") as demo:
129
  gr.Examples(
130
  examples=[
131
  ["An ancient Aksumite warrior in cyberpunk armor, 4k detailed"],
132
- ["Traditional Ethiopian coffee ceremony in zero gravity, photorealistic"],
133
- ["Portrait of a Habesha queen with golden jewelry, studio lighting"]
134
  ],
135
- inputs=prompt_input,
136
- label="Try these SDXL-optimized prompts:"
137
  )
138
 
139
  with gr.Column(scale=2):
140
  output_image = gr.Image(
141
- label="Generated Image (1024x1024)",
142
- height=512,
143
- elem_id="output-image"
 
144
  )
145
  status_output = gr.Textbox(
146
  label="Status",
147
- interactive=False,
148
- elem_id="status-box"
149
  )
150
 
151
  generate_btn.click(
152
  fn=generate_image,
153
  inputs=prompt_input,
154
  outputs=[output_image, status_output],
155
- queue=True,
156
- show_progress="minimal"
157
  )
158
 
159
  clear_btn.click(
 
8
 
9
  # ===== CONFIGURATION =====
10
  HF_API_TOKEN = os.environ.get("HF_API_TOKEN")
11
+ MODEL_NAME = "stabilityai/stable-diffusion-xl-base-1.0"
12
  API_URL = f"https://api-inference.huggingface.co/models/{MODEL_NAME}"
13
  headers = {"Authorization": f"Bearer {HF_API_TOKEN}"}
14
  WATERMARK_TEXT = "SelamGPT"
15
  MAX_RETRIES = 3
16
+ TIMEOUT = 60
17
  EXECUTOR = ThreadPoolExecutor(max_workers=2)
18
 
19
  # ===== WATERMARK FUNCTION =====
20
  def add_watermark(image_bytes):
21
+ """Add watermark with optimized PNG output"""
22
  try:
23
  image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
24
+ draw = ImageDraw.Draw(image)
25
 
 
 
 
 
 
 
 
 
26
  font_size = 24
27
  try:
28
  font = ImageFont.truetype("Roboto-Bold.ttf", font_size)
 
30
  font = ImageFont.load_default(font_size)
31
 
32
  text_width = draw.textlength(WATERMARK_TEXT, font=font)
33
+ x = image.width - text_width - 10
34
+ y = image.height - 34
35
 
36
  draw.text((x+1, y+1), WATERMARK_TEXT, font=font, fill=(0, 0, 0, 128))
37
  draw.text((x, y), WATERMARK_TEXT, font=font, fill=(255, 255, 255))
38
 
39
+ # Convert to optimized PNG
40
+ img_byte_arr = io.BytesIO()
41
+ image.save(img_byte_arr, format='PNG', optimize=True, quality=85)
42
+ img_byte_arr.seek(0)
43
+ return Image.open(img_byte_arr)
 
44
  except Exception as e:
45
  print(f"Watermark error: {str(e)}")
46
  return Image.open(io.BytesIO(image_bytes))
47
 
48
+ # ===== IMAGE GENERATION =====
49
  def generate_image(prompt):
 
50
  if not prompt.strip():
51
  return None, "⚠️ Please enter a prompt"
52
 
 
57
  json={
58
  "inputs": prompt,
59
  "parameters": {
60
+ "height": 1024,
61
  "width": 1024,
62
+ "num_inference_steps": 30
 
63
  },
64
  "options": {"wait_for_model": True}
65
  },
 
74
  if response.status_code == 200:
75
  return add_watermark(response.content), "✔️ Generation successful"
76
  elif response.status_code == 503:
77
+ wait_time = (attempt + 1) * 15
78
  print(f"Model loading, waiting {wait_time}s...")
79
  time.sleep(wait_time)
80
  continue
 
87
 
88
  return None, "⚠️ Failed after multiple attempts. Please try later."
89
 
90
+ # ===== GRADIO THEME =====
91
+ theme = gr.themes.Default(
92
+ primary_hue="emerald",
93
+ secondary_hue="amber",
94
+ font=[gr.themes.GoogleFont("Poppins"), "Arial", "sans-serif"]
95
+ )
 
 
 
96
 
97
+ # ===== GRADIO INTERFACE =====
98
  with gr.Blocks(theme=theme, title="SelamGPT Image Generator") as demo:
99
  gr.Markdown("""
100
  # 🎨 SelamGPT Image Generator
101
+ *Powered by Stable Diffusion XL (1024x1024 PNG output)*
102
  """)
103
 
104
  with gr.Row():
 
107
  label="Describe your image",
108
  placeholder="A futuristic Ethiopian city with flying cars...",
109
  lines=3,
110
+ max_lines=5
 
111
  )
112
  with gr.Row():
113
  generate_btn = gr.Button("Generate Image", variant="primary")
 
116
  gr.Examples(
117
  examples=[
118
  ["An ancient Aksumite warrior in cyberpunk armor, 4k detailed"],
119
+ ["Traditional Ethiopian coffee ceremony in zero gravity"],
120
+ ["Portrait of a Habesha queen with golden jewelry"]
121
  ],
122
+ inputs=prompt_input
 
123
  )
124
 
125
  with gr.Column(scale=2):
126
  output_image = gr.Image(
127
+ label="Generated Image",
128
+ type="pil",
129
+ format="png",
130
+ height=512
131
  )
132
  status_output = gr.Textbox(
133
  label="Status",
134
+ interactive=False
 
135
  )
136
 
137
  generate_btn.click(
138
  fn=generate_image,
139
  inputs=prompt_input,
140
  outputs=[output_image, status_output],
141
+ queue=True
 
142
  )
143
 
144
  clear_btn.click(