snackshell commited on
Commit
bc1c1c7
·
verified ·
1 Parent(s): e3e741b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +86 -78
app.py CHANGED
@@ -1,101 +1,117 @@
 
 
1
  import gradio as gr
2
  from PIL import Image, ImageDraw, ImageFont
3
  import io
4
- import torch
5
- from diffusers import DiffusionPipeline
6
 
7
  # ===== CONFIGURATION =====
8
- MODEL_NAME = "HiDream-ai/HiDream-I1-Full"
 
 
 
9
  WATERMARK_TEXT = "SelamGPT"
10
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
11
- TORCH_DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
12
-
13
- def load_model():
14
- pipe = DiffusionPipeline.from_pretrained(
15
- MODEL_NAME,
16
- torch_dtype=TORCH_DTYPE
17
- ).to(DEVICE)
18
-
19
- # Optimizations
20
- if DEVICE == "cuda":
21
- try:
22
- pipe.enable_xformers_memory_efficient_attention()
23
- except:
24
- print("Xformers not available, using default attention")
25
- pipe.enable_attention_slicing()
26
-
27
- return pipe
28
 
29
  # ===== WATERMARK FUNCTION =====
30
- def add_watermark(image):
31
  """Add watermark with optimized PNG output"""
32
  try:
 
33
  draw = ImageDraw.Draw(image)
34
 
35
- font_size = max(24, int(image.width * 0.03)) # Dynamic font sizing
36
  try:
37
  font = ImageFont.truetype("Roboto-Bold.ttf", font_size)
38
  except:
39
  font = ImageFont.load_default(font_size)
40
 
41
  text_width = draw.textlength(WATERMARK_TEXT, font=font)
42
- margin = image.width * 0.02 # Dynamic margin
43
- x = image.width - text_width - margin
44
- y = image.height - (font_size * 1.5)
45
 
46
- # Shadow effect
47
- draw.text((x+2, y+2), WATERMARK_TEXT, font=font, fill=(0, 0, 0, 150))
48
- # Main text
49
- draw.text((x, y), WATERMARK_TEXT, font=font, fill=(255, 215, 0)) # Gold color
50
 
51
- # Optimized PNG output
52
  img_byte_arr = io.BytesIO()
53
- image.save(img_byte_arr, format='PNG', optimize=True)
 
54
  return Image.open(img_byte_arr)
55
  except Exception as e:
56
  print(f"Watermark error: {str(e)}")
57
- return image
58
 
59
  # ===== IMAGE GENERATION =====
60
  def generate_image(prompt):
61
  if not prompt.strip():
62
- raise gr.Error("Please enter a prompt")
63
 
64
- try:
65
- model = load_model()
66
- result = model(
67
- prompt,
68
- num_inference_steps=30,
69
- guidance_scale=7.5,
70
- width=1024,
71
- height=1024
 
 
 
 
 
 
72
  )
73
- return add_watermark(result.images[0]), "🎨 Generation complete!"
74
 
75
- except torch.cuda.OutOfMemoryError:
76
- raise gr.Error("Out of memory! Try a simpler prompt or smaller image size")
77
- except Exception as e:
78
- raise gr.Error(f"Generation failed: {str(e)[:200]}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
 
80
- # ===== GRADIO 5.x INTERFACE =====
81
- with gr.Blocks(theme=gr.themes.Default(
82
  primary_hue="emerald",
83
- secondary_hue="gold",
84
  font=[gr.themes.GoogleFont("Poppins"), "Arial", "sans-serif"]
85
- )) as demo:
86
-
87
- gr.Markdown("""<h1 align="center">🎨 SelamGPT HiDream Generator</h1>""")
 
 
 
 
 
88
 
89
- with gr.Row(variant="panel"):
90
  with gr.Column(scale=3):
91
  prompt_input = gr.Textbox(
92
  label="Describe your image",
93
  placeholder="A futuristic Ethiopian city with flying cars...",
94
  lines=3,
95
- max_lines=5,
96
- autofocus=True
97
  )
98
- generate_btn = gr.Button("Generate Image", variant="primary")
 
 
99
 
100
  gr.Examples(
101
  examples=[
@@ -103,41 +119,33 @@ with gr.Blocks(theme=gr.themes.Default(
103
  ["Traditional Ethiopian coffee ceremony in zero gravity"],
104
  ["Portrait of a Habesha queen with golden jewelry"]
105
  ],
106
- inputs=prompt_input,
107
- label="Try these prompts:"
108
  )
109
-
110
  with gr.Column(scale=2):
111
  output_image = gr.Image(
112
  label="Generated Image",
113
  type="pil",
114
- height=512,
115
- interactive=False
116
  )
117
- status = gr.Textbox(
118
  label="Status",
119
- interactive=False,
120
- show_label=False
121
  )
122
 
123
- # Event handlers
124
  generate_btn.click(
125
  fn=generate_image,
126
  inputs=prompt_input,
127
- outputs=[output_image, status],
128
- api_name="generate"
129
  )
130
 
131
- # Keyboard shortcut
132
- prompt_input.submit(
133
- fn=generate_image,
134
- inputs=prompt_input,
135
- outputs=[output_image, status]
136
  )
137
 
138
  if __name__ == "__main__":
139
- demo.launch(
140
- server_name="0.0.0.0",
141
- server_port=7860,
142
- share=False
143
- )
 
1
+ import os
2
+ import requests
3
  import gradio as gr
4
  from PIL import Image, ImageDraw, ImageFont
5
  import io
6
+ import time
7
+ from concurrent.futures import ThreadPoolExecutor
8
 
9
  # ===== CONFIGURATION =====
10
+ HF_API_TOKEN = os.environ.get("HF_API_TOKEN")
11
+ MODEL_NAME = "stabilityai/stable-diffusion-xl-base-1.0"
12
+ API_URL = f"https://api-inference.huggingface.co/models/{MODEL_NAME}"
13
+ headers = {"Authorization": f"Bearer {HF_API_TOKEN}"}
14
  WATERMARK_TEXT = "SelamGPT"
15
+ MAX_RETRIES = 3
16
+ TIMEOUT = 60
17
+ EXECUTOR = ThreadPoolExecutor(max_workers=2)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
  # ===== WATERMARK FUNCTION =====
20
+ def add_watermark(image_bytes):
21
  """Add watermark with optimized PNG output"""
22
  try:
23
+ image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
24
  draw = ImageDraw.Draw(image)
25
 
26
+ font_size = 24
27
  try:
28
  font = ImageFont.truetype("Roboto-Bold.ttf", font_size)
29
  except:
30
  font = ImageFont.load_default(font_size)
31
 
32
  text_width = draw.textlength(WATERMARK_TEXT, font=font)
33
+ x = image.width - text_width - 10
34
+ y = image.height - 34
 
35
 
36
+ draw.text((x+1, y+1), WATERMARK_TEXT, font=font, fill=(0, 0, 0, 128))
37
+ draw.text((x, y), WATERMARK_TEXT, font=font, fill=(255, 255, 255))
 
 
38
 
39
+ # Convert to optimized PNG
40
  img_byte_arr = io.BytesIO()
41
+ image.save(img_byte_arr, format='PNG', optimize=True, quality=85)
42
+ img_byte_arr.seek(0)
43
  return Image.open(img_byte_arr)
44
  except Exception as e:
45
  print(f"Watermark error: {str(e)}")
46
+ return Image.open(io.BytesIO(image_bytes))
47
 
48
  # ===== IMAGE GENERATION =====
49
  def generate_image(prompt):
50
  if not prompt.strip():
51
+ return None, "⚠️ Please enter a prompt"
52
 
53
+ def api_call():
54
+ return requests.post(
55
+ API_URL,
56
+ headers=headers,
57
+ json={
58
+ "inputs": prompt,
59
+ "parameters": {
60
+ "height": 1024,
61
+ "width": 1024,
62
+ "num_inference_steps": 30
63
+ },
64
+ "options": {"wait_for_model": True}
65
+ },
66
+ timeout=TIMEOUT
67
  )
 
68
 
69
+ for attempt in range(MAX_RETRIES):
70
+ try:
71
+ future = EXECUTOR.submit(api_call)
72
+ response = future.result()
73
+
74
+ if response.status_code == 200:
75
+ return add_watermark(response.content), "✔️ Generation successful"
76
+ elif response.status_code == 503:
77
+ wait_time = (attempt + 1) * 15
78
+ print(f"Model loading, waiting {wait_time}s...")
79
+ time.sleep(wait_time)
80
+ continue
81
+ else:
82
+ return None, f"⚠️ API Error: {response.text[:200]}"
83
+ except requests.Timeout:
84
+ return None, f"⚠️ Timeout: Model took >{TIMEOUT}s to respond"
85
+ except Exception as e:
86
+ return None, f"⚠️ Unexpected error: {str(e)[:200]}"
87
+
88
+ return None, "⚠️ Failed after multiple attempts. Please try later."
89
 
90
+ # ===== GRADIO THEME =====
91
+ theme = gr.themes.Default(
92
  primary_hue="emerald",
93
+ secondary_hue="amber",
94
  font=[gr.themes.GoogleFont("Poppins"), "Arial", "sans-serif"]
95
+ )
96
+
97
+ # ===== GRADIO INTERFACE =====
98
+ with gr.Blocks(theme=theme, title="SelamGPT Image Generator") as demo:
99
+ gr.Markdown("""
100
+ # 🎨 SelamGPT Image Generator
101
+ *Powered by Stable Diffusion XL (1024x1024 PNG output)*
102
+ """)
103
 
104
+ with gr.Row():
105
  with gr.Column(scale=3):
106
  prompt_input = gr.Textbox(
107
  label="Describe your image",
108
  placeholder="A futuristic Ethiopian city with flying cars...",
109
  lines=3,
110
+ max_lines=5
 
111
  )
112
+ with gr.Row():
113
+ generate_btn = gr.Button("Generate Image", variant="primary")
114
+ clear_btn = gr.Button("Clear")
115
 
116
  gr.Examples(
117
  examples=[
 
119
  ["Traditional Ethiopian coffee ceremony in zero gravity"],
120
  ["Portrait of a Habesha queen with golden jewelry"]
121
  ],
122
+ inputs=prompt_input
 
123
  )
124
+
125
  with gr.Column(scale=2):
126
  output_image = gr.Image(
127
  label="Generated Image",
128
  type="pil",
129
+ format="png",
130
+ height=512
131
  )
132
+ status_output = gr.Textbox(
133
  label="Status",
134
+ interactive=False
 
135
  )
136
 
 
137
  generate_btn.click(
138
  fn=generate_image,
139
  inputs=prompt_input,
140
+ outputs=[output_image, status_output],
141
+ queue=True
142
  )
143
 
144
+ clear_btn.click(
145
+ fn=lambda: [None, ""],
146
+ outputs=[output_image, status_output]
 
 
147
  )
148
 
149
  if __name__ == "__main__":
150
+ demo.queue(max_size=2)
151
+ demo.launch(server_name="0.0.0.0", server_port=7860)