snackshell committed on
Commit
df1e443
·
verified ·
1 Parent(s): aca8344

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -49
app.py CHANGED
@@ -1,26 +1,37 @@
1
  import os
2
- import requests
3
  import gradio as gr
4
  from PIL import Image, ImageDraw, ImageFont
5
  import io
6
- import time
7
- from concurrent.futures import ThreadPoolExecutor
8
 
9
  # ===== CONFIGURATION =====
10
- HF_API_TOKEN = os.environ.get("HF_API_TOKEN")
11
- MODEL_NAME = "stabilityai/stable-diffusion-xl-base-1.0"
12
- API_URL = f"https://api-inference.huggingface.co/models/{MODEL_NAME}"
13
- headers = {"Authorization": f"Bearer {HF_API_TOKEN}"}
14
  WATERMARK_TEXT = "SelamGPT"
15
- MAX_RETRIES = 3
16
- TIMEOUT = 60
17
- EXECUTOR = ThreadPoolExecutor(max_workers=2)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
  # ===== WATERMARK FUNCTION =====
20
- def add_watermark(image_bytes):
21
  """Add watermark with optimized PNG output"""
22
  try:
23
- image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
24
  draw = ImageDraw.Draw(image)
25
 
26
  font_size = 24
@@ -43,49 +54,29 @@ def add_watermark(image_bytes):
43
  return Image.open(img_byte_arr)
44
  except Exception as e:
45
  print(f"Watermark error: {str(e)}")
46
- return Image.open(io.BytesIO(image_bytes))
47
 
48
  # ===== IMAGE GENERATION =====
49
  def generate_image(prompt):
50
  if not prompt.strip():
51
  return None, "⚠️ Please enter a prompt"
52
 
53
- def api_call():
54
- return requests.post(
55
- API_URL,
56
- headers=headers,
57
- json={
58
- "inputs": prompt,
59
- "parameters": {
60
- "height": 1024,
61
- "width": 1024,
62
- "num_inference_steps": 30
63
- },
64
- "options": {"wait_for_model": True}
65
- },
66
- timeout=TIMEOUT
67
- )
68
-
69
- for attempt in range(MAX_RETRIES):
70
- try:
71
- future = EXECUTOR.submit(api_call)
72
- response = future.result()
73
-
74
- if response.status_code == 200:
75
- return add_watermark(response.content), "✔️ Generation successful"
76
- elif response.status_code == 503:
77
- wait_time = (attempt + 1) * 15
78
- print(f"Model loading, waiting {wait_time}s...")
79
- time.sleep(wait_time)
80
- continue
81
- else:
82
- return None, f"⚠️ API Error: {response.text[:200]}"
83
- except requests.Timeout:
84
- return None, f"⚠️ Timeout: Model took >{TIMEOUT}s to respond"
85
- except Exception as e:
86
- return None, f"⚠️ Unexpected error: {str(e)[:200]}"
87
 
88
- return None, "⚠️ Failed after multiple attempts. Please try later."
 
 
 
89
 
90
  # ===== GRADIO THEME =====
91
  theme = gr.themes.Default(
@@ -98,7 +89,7 @@ theme = gr.themes.Default(
98
  with gr.Blocks(theme=theme, title="SelamGPT Image Generator") as demo:
99
  gr.Markdown("""
100
  # 🎨 SelamGPT Image Generator
101
- *Powered by Stable Diffusion XL (1024x1024 PNG output)*
102
  """)
103
 
104
  with gr.Row():
 
1
  import os
 
2
  import gradio as gr
3
  from PIL import Image, ImageDraw, ImageFont
4
  import io
5
+ import torch
6
+ from diffusers import DiffusionPipeline
7
 
8
  # ===== CONFIGURATION =====
9
+ MODEL_NAME = "HiDream-ai/HiDream-I1-Full"
 
 
 
10
  WATERMARK_TEXT = "SelamGPT"
11
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
12
+ TORCH_DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
13
+
14
# ===== MODEL LOADING =====
def load_model():
    """Load the diffusion pipeline once and move it to DEVICE.

    Returns the ready-to-call ``DiffusionPipeline``. Called a single time
    at module import, so the loaded model is naturally shared across all
    generations — no caching decorator is needed (``gr.Cache`` is not a
    public Gradio API and would raise ``AttributeError``).
    """
    pipe = DiffusionPipeline.from_pretrained(
        MODEL_NAME,
        torch_dtype=TORCH_DTYPE,
    ).to(DEVICE)

    # GPU-only optimizations
    if DEVICE == "cuda":
        try:
            # xformers is an optional dependency; fall back gracefully
            # instead of crashing startup when it is absent.
            pipe.enable_xformers_memory_efficient_attention()
        except Exception as e:
            print(f"xformers unavailable, continuing without it: {e}")
        pipe.enable_attention_slicing()

    return pipe

pipe = load_model()
30
 
31
  # ===== WATERMARK FUNCTION =====
32
+ def add_watermark(image):
33
  """Add watermark with optimized PNG output"""
34
  try:
 
35
  draw = ImageDraw.Draw(image)
36
 
37
  font_size = 24
 
54
  return Image.open(img_byte_arr)
55
  except Exception as e:
56
  print(f"Watermark error: {str(e)}")
57
+ return image
58
 
# ===== IMAGE GENERATION =====
def generate_image(prompt):
    """Generate a 1024x1024 watermarked image for *prompt*.

    Returns a ``(image, status)`` tuple: a PIL image (or ``None`` on
    failure) plus a human-readable status string, so the Gradio UI can
    always display feedback.
    """
    if not prompt.strip():
        return None, "⚠️ Please enter a prompt"

    try:
        # Generate image (1024x1024 by default for this pipeline)
        image = pipe(
            prompt,
            num_inference_steps=30,
            guidance_scale=7.5
        ).images[0]

        # Brand the output before returning it to the UI.
        watermarked = add_watermark(image)
        return watermarked, "✔️ Generation successful"
    except torch.cuda.OutOfMemoryError:
        # Release cached allocations so the *next* request has a chance
        # to succeed instead of inheriting a full VRAM cache.
        torch.cuda.empty_cache()
        return None, "⚠️ Out of memory! Try a simpler prompt"
    except Exception as e:
        return None, f"⚠️ Error: {str(e)[:200]}"
80
 
81
  # ===== GRADIO THEME =====
82
  theme = gr.themes.Default(
 
89
  with gr.Blocks(theme=theme, title="SelamGPT Image Generator") as demo:
90
  gr.Markdown("""
91
  # 🎨 SelamGPT Image Generator
92
+ *Powered by HiDream-I1-Full (1024x1024 PNG output)*
93
  """)
94
 
95
  with gr.Row():