awacke1 committed on
Commit
d67b9b4
·
verified ·
1 Parent(s): 1b79bfc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +128 -88
app.py CHANGED
@@ -7,18 +7,55 @@ from PIL import Image
7
  import io
8
  import requests
9
  from huggingface_hub import HfApi, login
 
 
10
 
11
- # Initialize session state - must be first
12
- if 'hf_token' not in st.session_state:
13
- st.session_state['hf_token'] = None
14
- if 'is_authenticated' not in st.session_state:
15
- st.session_state['is_authenticated'] = False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
  class ModelGenerator:
18
- @staticmethod
19
- def generate_midjourney(prompt, token):
 
 
20
  try:
21
- client = Client("mukaist/Midjourney", hf_token=token)
22
  result = client.predict(
23
  prompt=prompt,
24
  negative_prompt="(deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime:1.4), text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck",
@@ -37,21 +74,18 @@ class ModelGenerator:
37
  if isinstance(image_data, str):
38
  if image_data.startswith('http'):
39
  response = requests.get(image_data)
40
- image = Image.open(io.BytesIO(response.content))
41
  else:
42
- image = Image.open(image_data)
43
  else:
44
- image = Image.open(io.BytesIO(image_data))
45
- return ("Midjourney", image)
46
- else:
47
- return ("Midjourney", f"Error: Unexpected result format: {type(result)}")
48
  except Exception as e:
49
  return ("Midjourney", f"Error: {str(e)}")
50
 
51
- @staticmethod
52
- def generate_stable_cascade(prompt, token):
53
  try:
54
- client = Client("multimodalart/stable-cascade", hf_token=token)
55
  result = client.predict(
56
  prompt=prompt,
57
  negative_prompt=prompt,
@@ -65,14 +99,23 @@ class ModelGenerator:
65
  num_images_per_prompt=1,
66
  api_name="/run"
67
  )
68
- return ("Stable Cascade", result)
 
 
 
 
 
 
 
 
 
 
69
  except Exception as e:
70
  return ("Stable Cascade", f"Error: {str(e)}")
71
 
72
- @staticmethod
73
- def generate_stable_diffusion_3(prompt, token):
74
  try:
75
- client = Client("stabilityai/stable-diffusion-3-medium", hf_token=token)
76
  result = client.predict(
77
  prompt=prompt,
78
  negative_prompt=prompt,
@@ -84,14 +127,15 @@ class ModelGenerator:
84
  num_inference_steps=28,
85
  api_name="/infer"
86
  )
87
- return ("SD 3 Medium", result)
 
 
88
  except Exception as e:
89
  return ("SD 3 Medium", f"Error: {str(e)}")
90
 
91
- @staticmethod
92
- def generate_stable_diffusion_35(prompt, token):
93
  try:
94
- client = Client("stabilityai/stable-diffusion-3.5-large", hf_token=token)
95
  result = client.predict(
96
  prompt=prompt,
97
  negative_prompt=prompt,
@@ -103,15 +147,15 @@ class ModelGenerator:
103
  num_inference_steps=40,
104
  api_name="/infer"
105
  )
106
- return ("SD 3.5 Large", result)
 
 
107
  except Exception as e:
108
  return ("SD 3.5 Large", f"Error: {str(e)}")
109
 
110
- @staticmethod
111
- def generate_playground_v2_5(prompt, token):
112
  try:
113
- client = Client("https://playgroundai-playground-v2-5.hf.space/--replicas/ji5gy/",
114
- hf_token=token)
115
  result = client.predict(
116
  prompt,
117
  prompt, # negative prompt
@@ -123,26 +167,30 @@ class ModelGenerator:
123
  True, # randomize seed
124
  api_name="/run"
125
  )
126
- if result and isinstance(result, tuple) and result[0]:
127
- return ("Playground v2.5", result[0][0]['image'])
 
 
 
 
 
 
 
128
  return ("Playground v2.5", "Error: No image generated")
129
  except Exception as e:
130
  return ("Playground v2.5", f"Error: {str(e)}")
131
 
132
- def generate_images(prompt, selected_models):
133
- token = st.session_state.get('hf_token')
134
- if not token:
135
- return [("Error", "No authentication token found")]
136
-
137
  results = []
138
  with concurrent.futures.ThreadPoolExecutor() as executor:
139
  futures = []
 
140
  model_map = {
141
- "Midjourney": lambda p: ModelGenerator.generate_midjourney(p, token),
142
- "Stable Cascade": lambda p: ModelGenerator.generate_stable_cascade(p, token),
143
- "SD 3 Medium": lambda p: ModelGenerator.generate_stable_diffusion_3(p, token),
144
- "SD 3.5 Large": lambda p: ModelGenerator.generate_stable_diffusion_35(p, token),
145
- "Playground v2.5": lambda p: ModelGenerator.generate_playground_v2_5(p, token)
146
  }
147
 
148
  for model in selected_models:
@@ -150,17 +198,20 @@ def generate_images(prompt, selected_models):
150
  futures.append(executor.submit(model_map[model], prompt))
151
 
152
  for future in concurrent.futures.as_completed(futures):
153
- results.append(future.result())
 
 
 
 
 
154
 
155
  return results
156
 
157
  def handle_prompt_click(prompt_text, key):
158
- if not st.session_state.get('is_authenticated') or not st.session_state.get('hf_token'):
159
  st.error("Please login with your HuggingFace account first!")
160
  return
161
 
162
- st.session_state[f'selected_prompt_{key}'] = prompt_text
163
-
164
  selected_models = st.session_state.get('selected_models', [])
165
 
166
  if not selected_models:
@@ -168,42 +219,31 @@ def handle_prompt_click(prompt_text, key):
168
  return
169
 
170
  with st.spinner('Generating artwork...'):
171
- results = generate_images(prompt_text, selected_models)
172
- st.session_state[f'generated_images_{key}'] = results
173
- st.success("Artwork generated successfully!")
 
 
 
 
 
 
 
 
 
 
 
 
 
174
 
175
  def main():
176
  st.title("🎨 Multi-Model Art Generator")
 
 
 
177
 
178
- # Handle authentication in sidebar
179
- with st.sidebar:
180
- st.header("🔐 Authentication")
181
- if st.session_state.get('is_authenticated') and st.session_state.get('hf_token'):
182
- st.success("✓ Logged in to HuggingFace")
183
- if st.button("Logout"):
184
- st.session_state['hf_token'] = None
185
- st.session_state['is_authenticated'] = False
186
- st.rerun()
187
- else:
188
- token = st.text_input("Enter HuggingFace Token", type="password",
189
- help="Get your token from https://huggingface.co/settings/tokens")
190
- if st.button("Login"):
191
- if token:
192
- try:
193
- # Verify token is valid
194
- api = HfApi(token=token)
195
- api.whoami()
196
- st.session_state['hf_token'] = token
197
- st.session_state['is_authenticated'] = True
198
- st.success("Successfully logged in!")
199
- st.rerun()
200
- except Exception as e:
201
- st.error(f"Authentication failed: {str(e)}")
202
- else:
203
- st.error("Please enter your HuggingFace token")
204
-
205
- if st.session_state.get('is_authenticated') and st.session_state.get('hf_token'):
206
- st.markdown("---")
207
  st.header("Model Selection")
208
  st.session_state['selected_models'] = st.multiselect(
209
  "Choose AI Models",
@@ -226,8 +266,6 @@ def main():
226
  - **Playground v2.5**: Advanced model with high customization
227
  """)
228
 
229
- # Only show the main interface if authenticated
230
- if st.session_state.get('is_authenticated') and st.session_state.get('hf_token'):
231
  st.markdown("### Select a prompt style to generate artwork:")
232
 
233
  prompt_emojis = {
@@ -278,23 +316,25 @@ def main():
278
  st.markdown("---")
279
  st.markdown("### Generated Artwork:")
280
 
 
281
  for key in st.session_state:
282
- if key.startswith('selected_prompt_'):
283
  idx = key.split('_')[-1]
284
- images_key = f'generated_images_{idx}'
285
 
286
- if images_key in st.session_state:
287
- st.write("Prompt:", st.session_state[key])
288
 
289
- cols = st.columns(len(st.session_state[images_key]))
290
-
291
- for col, (model_name, result) in zip(cols, st.session_state[images_key]):
292
  with col:
293
  st.markdown(f"**{model_name}**")
294
  if isinstance(result, str) and result.startswith("Error"):
295
  st.error(result)
296
- else:
297
  st.image(result, use_container_width=True)
 
 
298
  else:
299
  st.info("Please login with your HuggingFace account to use the app")
300
 
 
7
  import io
8
  import requests
9
  from huggingface_hub import HfApi, login
10
+ from pathlib import Path
11
+ import json
12
 
13
def init_session_state():
    """Seed the auth-related Streamlit session-state keys with safe defaults."""
    defaults = {'hf_token': None, 'is_authenticated': False}
    for state_key, default_value in defaults.items():
        # Only create the key when absent so reruns never clobber live state.
        if state_key not in st.session_state:
            st.session_state[state_key] = default_value
19
+
20
def save_token(token):
    """Persist a verified HuggingFace token and flag the session as logged in."""
    st.session_state['hf_token'] = token
    st.session_state['is_authenticated'] = True
24
+
25
def authenticate_user():
    """Render the sidebar login/logout UI and validate tokens with HfApi.

    On successful login the token is saved to session state and the app is
    rerun; on logout the auth state is cleared and the app is rerun.
    Token validity is checked via ``HfApi.whoami()``, which raises when the
    token is rejected.
    """
    st.sidebar.markdown("## 🔐 Authentication")

    if st.session_state.is_authenticated:
        # Already logged in: offer logout only.
        st.sidebar.success("✓ Logged in to HuggingFace")
        if st.sidebar.button("Logout"):
            st.session_state.hf_token = None
            st.session_state.is_authenticated = False
            st.rerun()
        return

    token = st.sidebar.text_input(
        "Enter HuggingFace Token",
        type="password",
        help="Get your token from https://huggingface.co/settings/tokens",
    )
    if st.sidebar.button("Login"):
        if not token:
            st.sidebar.error("Please enter your HuggingFace token")
            return
        try:
            # Keep the try minimal: only the verification call should be
            # guarded.  st.rerun() raises RerunException, so a broad
            # `except Exception` wrapped around it would swallow the rerun
            # of a *successful* login and falsely report a failure.
            api = HfApi(token=token)
            api.whoami()  # raises if the token is invalid
        except Exception as e:
            st.sidebar.error(f"Authentication failed: {str(e)}")
            return
        save_token(token)
        st.sidebar.success("Successfully logged in!")
        st.rerun()
51
 
52
  class ModelGenerator:
53
+ def __init__(self, token):
54
+ self.token = token
55
+
56
+ def generate_midjourney(self, prompt):
57
  try:
58
+ client = Client("mukaist/Midjourney", hf_token=self.token)
59
  result = client.predict(
60
  prompt=prompt,
61
  negative_prompt="(deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime:1.4), text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck",
 
74
  if isinstance(image_data, str):
75
  if image_data.startswith('http'):
76
  response = requests.get(image_data)
77
+ return ("Midjourney", Image.open(io.BytesIO(response.content)))
78
  else:
79
+ return ("Midjourney", Image.open(image_data))
80
  else:
81
+ return ("Midjourney", Image.open(io.BytesIO(image_data)))
82
+ return ("Midjourney", "Error: No image generated")
 
 
83
  except Exception as e:
84
  return ("Midjourney", f"Error: {str(e)}")
85
 
86
+ def generate_stable_cascade(self, prompt):
 
87
  try:
88
+ client = Client("multimodalart/stable-cascade", hf_token=self.token)
89
  result = client.predict(
90
  prompt=prompt,
91
  negative_prompt=prompt,
 
99
  num_images_per_prompt=1,
100
  api_name="/run"
101
  )
102
+ if isinstance(result, list) and len(result) > 0:
103
+ image_data = result[0]
104
+ if isinstance(image_data, str):
105
+ if image_data.startswith('http'):
106
+ response = requests.get(image_data)
107
+ return ("Stable Cascade", Image.open(io.BytesIO(response.content)))
108
+ else:
109
+ return ("Stable Cascade", Image.open(image_data))
110
+ else:
111
+ return ("Stable Cascade", Image.open(io.BytesIO(image_data)))
112
+ return ("Stable Cascade", "Error: No image generated")
113
  except Exception as e:
114
  return ("Stable Cascade", f"Error: {str(e)}")
115
 
116
+ def generate_stable_diffusion_3(self, prompt):
 
117
  try:
118
+ client = Client("stabilityai/stable-diffusion-3-medium", hf_token=self.token)
119
  result = client.predict(
120
  prompt=prompt,
121
  negative_prompt=prompt,
 
127
  num_inference_steps=28,
128
  api_name="/infer"
129
  )
130
+ if isinstance(result, (str, bytes)):
131
+ return ("SD 3 Medium", Image.open(io.BytesIO(result) if isinstance(result, bytes) else result))
132
+ return ("SD 3 Medium", "Error: Unexpected result format")
133
  except Exception as e:
134
  return ("SD 3 Medium", f"Error: {str(e)}")
135
 
136
+ def generate_stable_diffusion_35(self, prompt):
 
137
  try:
138
+ client = Client("stabilityai/stable-diffusion-3.5-large", hf_token=self.token)
139
  result = client.predict(
140
  prompt=prompt,
141
  negative_prompt=prompt,
 
147
  num_inference_steps=40,
148
  api_name="/infer"
149
  )
150
+ if isinstance(result, (str, bytes)):
151
+ return ("SD 3.5 Large", Image.open(io.BytesIO(result) if isinstance(result, bytes) else result))
152
+ return ("SD 3.5 Large", "Error: Unexpected result format")
153
  except Exception as e:
154
  return ("SD 3.5 Large", f"Error: {str(e)}")
155
 
156
+ def generate_playground_v2_5(self, prompt):
 
157
  try:
158
+ client = Client("https://playgroundai-playground-v2-5.hf.space/--replicas/ji5gy/", hf_token=self.token)
 
159
  result = client.predict(
160
  prompt,
161
  prompt, # negative prompt
 
167
  True, # randomize seed
168
  api_name="/run"
169
  )
170
+ if isinstance(result, tuple) and result[0] and len(result[0]) > 0:
171
+ image_data = result[0][0].get('image')
172
+ if image_data:
173
+ if isinstance(image_data, str):
174
+ if image_data.startswith('http'):
175
+ response = requests.get(image_data)
176
+ return ("Playground v2.5", Image.open(io.BytesIO(response.content)))
177
+ return ("Playground v2.5", Image.open(image_data))
178
+ return ("Playground v2.5", Image.open(io.BytesIO(image_data)))
179
  return ("Playground v2.5", "Error: No image generated")
180
  except Exception as e:
181
  return ("Playground v2.5", f"Error: {str(e)}")
182
 
183
+ def generate_images(prompt, selected_models, token):
 
 
 
 
184
  results = []
185
  with concurrent.futures.ThreadPoolExecutor() as executor:
186
  futures = []
187
+ generator = ModelGenerator(token)
188
  model_map = {
189
+ "Midjourney": generator.generate_midjourney,
190
+ "Stable Cascade": generator.generate_stable_cascade,
191
+ "SD 3 Medium": generator.generate_stable_diffusion_3,
192
+ "SD 3.5 Large": generator.generate_stable_diffusion_35,
193
+ "Playground v2.5": generator.generate_playground_v2_5
194
  }
195
 
196
  for model in selected_models:
 
198
  futures.append(executor.submit(model_map[model], prompt))
199
 
200
  for future in concurrent.futures.as_completed(futures):
201
+ try:
202
+ result = future.result()
203
+ if result:
204
+ results.append(result)
205
+ except Exception as e:
206
+ st.error(f"Error during image generation: {str(e)}")
207
 
208
  return results
209
 
210
  def handle_prompt_click(prompt_text, key):
211
+ if not st.session_state.is_authenticated:
212
  st.error("Please login with your HuggingFace account first!")
213
  return
214
 
 
 
215
  selected_models = st.session_state.get('selected_models', [])
216
 
217
  if not selected_models:
 
219
  return
220
 
221
  with st.spinner('Generating artwork...'):
222
+ results = generate_images(prompt_text, selected_models, st.session_state.hf_token)
223
+ if results:
224
+ st.session_state[f'generated_images_{key}'] = results
225
+ st.success("Artwork generated successfully!")
226
+
227
+ # Display images immediately
228
+ cols = st.columns(len(results))
229
+ for col, (model_name, result) in zip(cols, results):
230
+ with col:
231
+ st.markdown(f"**{model_name}**")
232
+ if isinstance(result, str) and result.startswith("Error"):
233
+ st.error(result)
234
+ elif isinstance(result, Image.Image):
235
+ st.image(result, use_container_width=True)
236
+ else:
237
+ st.error(f"Unexpected result type: {type(result)}")
238
 
239
  def main():
240
  st.title("🎨 Multi-Model Art Generator")
241
+
242
+ init_session_state()
243
+ authenticate_user()
244
 
245
+ if st.session_state.is_authenticated:
246
+ with st.sidebar:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
247
  st.header("Model Selection")
248
  st.session_state['selected_models'] = st.multiselect(
249
  "Choose AI Models",
 
266
  - **Playground v2.5**: Advanced model with high customization
267
  """)
268
 
 
 
269
  st.markdown("### Select a prompt style to generate artwork:")
270
 
271
  prompt_emojis = {
 
316
  st.markdown("---")
317
  st.markdown("### Generated Artwork:")
318
 
319
+ # Display any previously generated images
320
  for key in st.session_state:
321
+ if key.startswith('generated_images_'):
322
  idx = key.split('_')[-1]
323
+ prompt_key = f'selected_prompt_{idx}'
324
 
325
+ if prompt_key in st.session_state:
326
+ st.write("Prompt:", st.session_state[prompt_key])
327
 
328
+ cols = st.columns(len(st.session_state[key]))
329
+ for col, (model_name, result) in zip(cols, st.session_state[key]):
 
330
  with col:
331
  st.markdown(f"**{model_name}**")
332
  if isinstance(result, str) and result.startswith("Error"):
333
  st.error(result)
334
+ elif isinstance(result, Image.Image):
335
  st.image(result, use_container_width=True)
336
+ else:
337
+ st.error(f"Unexpected result type: {type(result)}")
338
  else:
339
  st.info("Please login with your HuggingFace account to use the app")
340