awacke1 commited on
Commit
4dab44a
·
verified ·
1 Parent(s): 0e426ef

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +158 -130
app.py CHANGED
@@ -6,17 +6,56 @@ import os
6
  from PIL import Image
7
  import io
8
  import requests
 
 
 
9
 
10
- # Get token from environment variable
11
- HF_TOKEN = os.getenv('ArtToken')
12
- if not HF_TOKEN:
13
- raise ValueError("Please set the 'ArtToken' environment variable with your Hugging Face token")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
  class ModelGenerator:
16
- @staticmethod
17
- def generate_midjourney(prompt):
 
 
18
  try:
19
- client = Client("mukaist/Midjourney", hf_token=HF_TOKEN)
20
  result = client.predict(
21
  prompt=prompt,
22
  negative_prompt="(deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime:1.4), text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck",
@@ -30,20 +69,15 @@ class ModelGenerator:
30
  api_name="/run"
31
  )
32
 
33
- # Handle the result based on its type
34
  if isinstance(result, list) and len(result) > 0:
35
- # If result is a list of file paths or URLs
36
  image_data = result[0]
37
  if isinstance(image_data, str):
38
  if image_data.startswith('http'):
39
- # If it's a URL, download the image
40
  response = requests.get(image_data)
41
  image = Image.open(io.BytesIO(response.content))
42
  else:
43
- # If it's a file path
44
  image = Image.open(image_data)
45
  else:
46
- # If it's already image data
47
  image = Image.open(io.BytesIO(image_data))
48
  return ("Midjourney", image)
49
  else:
@@ -51,10 +85,9 @@ class ModelGenerator:
51
  except Exception as e:
52
  return ("Midjourney", f"Error: {str(e)}")
53
 
54
- @staticmethod
55
- def generate_stable_cascade(prompt):
56
  try:
57
- client = Client("multimodalart/stable-cascade", hf_token=HF_TOKEN)
58
  result = client.predict(
59
  prompt=prompt,
60
  negative_prompt=prompt,
@@ -72,10 +105,9 @@ class ModelGenerator:
72
  except Exception as e:
73
  return ("Stable Cascade", f"Error: {str(e)}")
74
 
75
- @staticmethod
76
- def generate_stable_diffusion_3(prompt):
77
  try:
78
- client = Client("stabilityai/stable-diffusion-3-medium", hf_token=HF_TOKEN)
79
  result = client.predict(
80
  prompt=prompt,
81
  negative_prompt=prompt,
@@ -91,10 +123,9 @@ class ModelGenerator:
91
  except Exception as e:
92
  return ("SD 3 Medium", f"Error: {str(e)}")
93
 
94
- @staticmethod
95
- def generate_stable_diffusion_35(prompt):
96
  try:
97
- client = Client("stabilityai/stable-diffusion-3.5-large", hf_token=HF_TOKEN)
98
  result = client.predict(
99
  prompt=prompt,
100
  negative_prompt=prompt,
@@ -110,10 +141,9 @@ class ModelGenerator:
110
  except Exception as e:
111
  return ("SD 3.5 Large", f"Error: {str(e)}")
112
 
113
- @staticmethod
114
- def generate_playground_v2_5(prompt):
115
  try:
116
- client = Client("https://playgroundai-playground-v2-5.hf.space/--replicas/ji5gy/", hf_token=HF_TOKEN)
117
  result = client.predict(
118
  prompt,
119
  prompt, # negative prompt
@@ -125,23 +155,23 @@ class ModelGenerator:
125
  True, # randomize seed
126
  api_name="/run"
127
  )
128
- # Result is a tuple (gallery, seed), we want just the first image from gallery
129
  if result and isinstance(result, tuple) and result[0]:
130
  return ("Playground v2.5", result[0][0]['image'])
131
  return ("Playground v2.5", "Error: No image generated")
132
  except Exception as e:
133
  return ("Playground v2.5", f"Error: {str(e)}")
134
 
135
- def generate_images(prompt, selected_models):
136
  results = []
137
  with concurrent.futures.ThreadPoolExecutor() as executor:
138
  futures = []
 
139
  model_map = {
140
- "Midjourney": ModelGenerator.generate_midjourney,
141
- "Stable Cascade": ModelGenerator.generate_stable_cascade,
142
- "SD 3 Medium": ModelGenerator.generate_stable_diffusion_3,
143
- "SD 3.5 Large": ModelGenerator.generate_stable_diffusion_35,
144
- "Playground v2.5": ModelGenerator.generate_playground_v2_5
145
  }
146
 
147
  for model in selected_models:
@@ -154,12 +184,10 @@ def generate_images(prompt, selected_models):
154
  return results
155
 
156
  def handle_prompt_click(prompt_text, key):
157
- if not HF_TOKEN:
158
- st.error("Environment variable 'ArtToken' is not set!")
159
  return
160
 
161
- st.session_state[f'selected_prompt_{key}'] = prompt_text
162
-
163
  selected_models = st.session_state.get('selected_models', [])
164
 
165
  if not selected_models:
@@ -167,113 +195,113 @@ def handle_prompt_click(prompt_text, key):
167
  return
168
 
169
  with st.spinner('Generating artwork...'):
170
- results = generate_images(prompt_text, selected_models)
171
  st.session_state[f'generated_images_{key}'] = results
172
  st.success("Artwork generated successfully!")
173
 
174
  def main():
175
  st.title("๐ŸŽจ Multi-Model Art Generator")
 
 
 
 
 
 
176
 
177
- with st.sidebar:
178
- st.header("Configuration")
179
-
180
- # Show token status
181
- if HF_TOKEN:
182
- st.success("โœ“ ArtToken loaded from environment")
183
- else:
184
- st.error("โš  ArtToken not found in environment")
185
-
186
- st.markdown("---")
187
- st.header("Model Selection")
188
- st.session_state['selected_models'] = st.multiselect(
189
- "Choose AI Models",
190
- ["Midjourney", "Stable Cascade", "SD 3 Medium", "SD 3.5 Large", "Playground v2.5"],
191
- default=["Midjourney"]
192
- )
193
-
194
- st.markdown("---")
195
- st.markdown("### Selected Models:")
196
- for model in st.session_state['selected_models']:
197
- st.write(f"โœ“ {model}")
198
-
199
- st.markdown("---")
200
- st.markdown("### Model Information:")
201
- st.markdown("""
202
- - **Midjourney**: Best for artistic and creative imagery
203
- - **Stable Cascade**: New architecture with high detail
204
- - **SD 3 Medium**: Fast and efficient generation
205
- - **SD 3.5 Large**: Highest quality, slower generation
206
- - **Playground v2.5**: Advanced model with high customization
207
- """)
208
 
209
- st.markdown("### Select a prompt style to generate artwork:")
210
 
211
- prompt_emojis = {
212
- "AIart/AIArtistCommunity": "๐Ÿค–",
213
- "Black & White": "โšซโšช",
214
- "Black & Yellow": "โšซ๐Ÿ’›",
215
- "Blindfold": "๐Ÿ™ˆ",
216
- "Break": "๐Ÿ’”",
217
- "Broken": "๐Ÿ”จ",
218
- "Christmas Celebrations art": "๐ŸŽ„",
219
- "Colorful Art": "๐ŸŽจ",
220
- "Crimson art": "๐Ÿ”ด",
221
- "Eyes Art": "๐Ÿ‘๏ธ",
222
- "Going out with Style": "๐Ÿ’ƒ",
223
- "Hooded Girl": "๐Ÿงฅ",
224
- "Lips": "๐Ÿ‘„",
225
- "MAEKHLONG": "๐Ÿฎ",
226
- "Mermaid": "๐Ÿงœโ€โ™€๏ธ",
227
- "Morning Sunshine": "๐ŸŒ…",
228
- "Music Art": "๐ŸŽต",
229
- "Owl": "๐Ÿฆ‰",
230
- "Pink": "๐Ÿ’—",
231
- "Purple": "๐Ÿ’œ",
232
- "Rain": "๐ŸŒง๏ธ",
233
- "Red Moon": "๐ŸŒ‘",
234
- "Rose": "๐ŸŒน",
235
- "Snow": "โ„๏ธ",
236
- "Spacesuit Girl": "๐Ÿ‘ฉโ€๐Ÿš€",
237
- "Steampunk": "โš™๏ธ",
238
- "Succubus": "๐Ÿ˜ˆ",
239
- "Sunlight": "โ˜€๏ธ",
240
- "Weird art": "๐ŸŽญ",
241
- "White Hair": "๐Ÿ‘ฑโ€โ™€๏ธ",
242
- "Wings art": "๐Ÿ‘ผ",
243
- "Woman with Sword": "โš”๏ธ"
244
- }
245
 
246
- col1, col2, col3 = st.columns(3)
247
-
248
- for idx, (prompt, emoji) in enumerate(prompt_emojis.items()):
249
- full_prompt = f"QT {prompt}"
250
- col = [col1, col2, col3][idx % 3]
251
 
252
- with col:
253
- if st.button(f"{emoji} {prompt}", key=f"btn_{idx}"):
254
- handle_prompt_click(full_prompt, idx)
255
-
256
- st.markdown("---")
257
- st.markdown("### Generated Artwork:")
258
-
259
- for key in st.session_state:
260
- if key.startswith('selected_prompt_'):
261
- idx = key.split('_')[-1]
262
- images_key = f'generated_images_{idx}'
263
 
264
- if images_key in st.session_state:
265
- st.write("Prompt:", st.session_state[key])
266
-
267
- cols = st.columns(len(st.session_state[images_key]))
 
 
 
 
 
 
 
268
 
269
- for col, (model_name, result) in zip(cols, st.session_state[images_key]):
270
- with col:
271
- st.markdown(f"**{model_name}**")
272
- if isinstance(result, str) and result.startswith("Error"):
273
- st.error(result)
274
- else:
275
- # Updated to use use_container_width instead of use_column_width
276
- st.image(result, use_container_width=True)
 
 
 
 
 
 
277
 
278
  if __name__ == "__main__":
279
  main()
 
6
  from PIL import Image
7
  import io
8
  import requests
9
+ from huggingface_hub import HfApi, login
10
+ from pathlib import Path
11
+ import json
12
 
13
+ def init_session_state():
14
+ """Initialize session state variables"""
15
+ if 'hf_token' not in st.session_state:
16
+ st.session_state.hf_token = None
17
+ if 'is_authenticated' not in st.session_state:
18
+ st.session_state.is_authenticated = False
19
+
20
+ def save_token(token):
21
+ """Save token to session state"""
22
+ st.session_state.hf_token = token
23
+ st.session_state.is_authenticated = True
24
+
25
+ def authenticate_user():
26
+ """Handle user authentication with HuggingFace"""
27
+ st.sidebar.markdown("## ๐Ÿ” Authentication")
28
+
29
+ if st.session_state.is_authenticated:
30
+ st.sidebar.success("โœ“ Logged in to HuggingFace")
31
+ if st.sidebar.button("Logout"):
32
+ st.session_state.hf_token = None
33
+ st.session_state.is_authenticated = False
34
+ st.experimental_rerun()
35
+ else:
36
+ token = st.sidebar.text_input("Enter HuggingFace Token", type="password",
37
+ help="Get your token from https://huggingface.co/settings/tokens")
38
+ if st.sidebar.button("Login"):
39
+ if token:
40
+ try:
41
+ # Verify token is valid
42
+ api = HfApi(token=token)
43
+ api.whoami()
44
+ save_token(token)
45
+ st.sidebar.success("Successfully logged in!")
46
+ st.experimental_rerun()
47
+ except Exception as e:
48
+ st.sidebar.error(f"Authentication failed: {str(e)}")
49
+ else:
50
+ st.sidebar.error("Please enter your HuggingFace token")
51
 
52
  class ModelGenerator:
53
+ def __init__(self, token):
54
+ self.token = token
55
+
56
+ def generate_midjourney(self, prompt):
57
  try:
58
+ client = Client("mukaist/Midjourney", hf_token=self.token)
59
  result = client.predict(
60
  prompt=prompt,
61
  negative_prompt="(deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime:1.4), text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck",
 
69
  api_name="/run"
70
  )
71
 
 
72
  if isinstance(result, list) and len(result) > 0:
 
73
  image_data = result[0]
74
  if isinstance(image_data, str):
75
  if image_data.startswith('http'):
 
76
  response = requests.get(image_data)
77
  image = Image.open(io.BytesIO(response.content))
78
  else:
 
79
  image = Image.open(image_data)
80
  else:
 
81
  image = Image.open(io.BytesIO(image_data))
82
  return ("Midjourney", image)
83
  else:
 
85
  except Exception as e:
86
  return ("Midjourney", f"Error: {str(e)}")
87
 
88
+ def generate_stable_cascade(self, prompt):
 
89
  try:
90
+ client = Client("multimodalart/stable-cascade", hf_token=self.token)
91
  result = client.predict(
92
  prompt=prompt,
93
  negative_prompt=prompt,
 
105
  except Exception as e:
106
  return ("Stable Cascade", f"Error: {str(e)}")
107
 
108
+ def generate_stable_diffusion_3(self, prompt):
 
109
  try:
110
+ client = Client("stabilityai/stable-diffusion-3-medium", hf_token=self.token)
111
  result = client.predict(
112
  prompt=prompt,
113
  negative_prompt=prompt,
 
123
  except Exception as e:
124
  return ("SD 3 Medium", f"Error: {str(e)}")
125
 
126
+ def generate_stable_diffusion_35(self, prompt):
 
127
  try:
128
+ client = Client("stabilityai/stable-diffusion-3.5-large", hf_token=self.token)
129
  result = client.predict(
130
  prompt=prompt,
131
  negative_prompt=prompt,
 
141
  except Exception as e:
142
  return ("SD 3.5 Large", f"Error: {str(e)}")
143
 
144
+ def generate_playground_v2_5(self, prompt):
 
145
  try:
146
+ client = Client("https://playgroundai-playground-v2-5.hf.space/--replicas/ji5gy/", hf_token=self.token)
147
  result = client.predict(
148
  prompt,
149
  prompt, # negative prompt
 
155
  True, # randomize seed
156
  api_name="/run"
157
  )
 
158
  if result and isinstance(result, tuple) and result[0]:
159
  return ("Playground v2.5", result[0][0]['image'])
160
  return ("Playground v2.5", "Error: No image generated")
161
  except Exception as e:
162
  return ("Playground v2.5", f"Error: {str(e)}")
163
 
164
+ def generate_images(prompt, selected_models, token):
165
  results = []
166
  with concurrent.futures.ThreadPoolExecutor() as executor:
167
  futures = []
168
+ generator = ModelGenerator(token)
169
  model_map = {
170
+ "Midjourney": generator.generate_midjourney,
171
+ "Stable Cascade": generator.generate_stable_cascade,
172
+ "SD 3 Medium": generator.generate_stable_diffusion_3,
173
+ "SD 3.5 Large": generator.generate_stable_diffusion_35,
174
+ "Playground v2.5": generator.generate_playground_v2_5
175
  }
176
 
177
  for model in selected_models:
 
184
  return results
185
 
186
  def handle_prompt_click(prompt_text, key):
187
+ if not st.session_state.is_authenticated:
188
+ st.error("Please login with your HuggingFace account first!")
189
  return
190
 
 
 
191
  selected_models = st.session_state.get('selected_models', [])
192
 
193
  if not selected_models:
 
195
  return
196
 
197
  with st.spinner('Generating artwork...'):
198
+ results = generate_images(prompt_text, selected_models, st.session_state.hf_token)
199
  st.session_state[f'generated_images_{key}'] = results
200
  st.success("Artwork generated successfully!")
201
 
202
  def main():
203
  st.title("๐ŸŽจ Multi-Model Art Generator")
204
+
205
+ # Initialize session state
206
+ init_session_state()
207
+
208
+ # Handle authentication
209
+ authenticate_user()
210
 
211
+ # Only show the main interface if authenticated
212
+ if st.session_state.is_authenticated:
213
+ with st.sidebar:
214
+ st.header("Model Selection")
215
+ st.session_state['selected_models'] = st.multiselect(
216
+ "Choose AI Models",
217
+ ["Midjourney", "Stable Cascade", "SD 3 Medium", "SD 3.5 Large", "Playground v2.5"],
218
+ default=["Midjourney"]
219
+ )
220
+
221
+ st.markdown("---")
222
+ st.markdown("### Selected Models:")
223
+ for model in st.session_state['selected_models']:
224
+ st.write(f"โœ“ {model}")
225
+
226
+ st.markdown("---")
227
+ st.markdown("### Model Information:")
228
+ st.markdown("""
229
+ - **Midjourney**: Best for artistic and creative imagery
230
+ - **Stable Cascade**: New architecture with high detail
231
+ - **SD 3 Medium**: Fast and efficient generation
232
+ - **SD 3.5 Large**: Highest quality, slower generation
233
+ - **Playground v2.5**: Advanced model with high customization
234
+ """)
 
 
 
 
 
 
 
235
 
236
+ st.markdown("### Select a prompt style to generate artwork:")
237
 
238
+ prompt_emojis = {
239
+ "AIart/AIArtistCommunity": "๐Ÿค–",
240
+ "Black & White": "โšซโšช",
241
+ "Black & Yellow": "โšซ๐Ÿ’›",
242
+ "Blindfold": "๐Ÿ™ˆ",
243
+ "Break": "๐Ÿ’”",
244
+ "Broken": "๐Ÿ”จ",
245
+ "Christmas Celebrations art": "๐ŸŽ„",
246
+ "Colorful Art": "๐ŸŽจ",
247
+ "Crimson art": "๐Ÿ”ด",
248
+ "Eyes Art": "๐Ÿ‘๏ธ",
249
+ "Going out with Style": "๐Ÿ’ƒ",
250
+ "Hooded Girl": "๐Ÿงฅ",
251
+ "Lips": "๐Ÿ‘„",
252
+ "MAEKHLONG": "๐Ÿฎ",
253
+ "Mermaid": "๐Ÿงœโ€โ™€๏ธ",
254
+ "Morning Sunshine": "๐ŸŒ…",
255
+ "Music Art": "๐ŸŽต",
256
+ "Owl": "๐Ÿฆ‰",
257
+ "Pink": "๐Ÿ’—",
258
+ "Purple": "๐Ÿ’œ",
259
+ "Rain": "๐ŸŒง๏ธ",
260
+ "Red Moon": "๐ŸŒ‘",
261
+ "Rose": "๐ŸŒน",
262
+ "Snow": "โ„๏ธ",
263
+ "Spacesuit Girl": "๐Ÿ‘ฉโ€๐Ÿš€",
264
+ "Steampunk": "โš™๏ธ",
265
+ "Succubus": "๐Ÿ˜ˆ",
266
+ "Sunlight": "โ˜€๏ธ",
267
+ "Weird art": "๐ŸŽญ",
268
+ "White Hair": "๐Ÿ‘ฑโ€โ™€๏ธ",
269
+ "Wings art": "๐Ÿ‘ผ",
270
+ "Woman with Sword": "โš”๏ธ"
271
+ }
272
 
273
+ col1, col2, col3 = st.columns(3)
 
 
 
 
274
 
275
+ for idx, (prompt, emoji) in enumerate(prompt_emojis.items()):
276
+ full_prompt = f"QT {prompt}"
277
+ col = [col1, col2, col3][idx % 3]
 
 
 
 
 
 
 
 
278
 
279
+ with col:
280
+ if st.button(f"{emoji} {prompt}", key=f"btn_{idx}"):
281
+ handle_prompt_click(full_prompt, idx)
282
+
283
+ st.markdown("---")
284
+ st.markdown("### Generated Artwork:")
285
+
286
+ for key in st.session_state:
287
+ if key.startswith('selected_prompt_'):
288
+ idx = key.split('_')[-1]
289
+ images_key = f'generated_images_{idx}'
290
 
291
+ if images_key in st.session_state:
292
+ st.write("Prompt:", st.session_state[key])
293
+
294
+ cols = st.columns(len(st.session_state[images_key]))
295
+
296
+ for col, (model_name, result) in zip(cols, st.session_state[images_key]):
297
+ with col:
298
+ st.markdown(f"**{model_name}**")
299
+ if isinstance(result, str) and result.startswith("Error"):
300
+ st.error(result)
301
+ else:
302
+ st.image(result, use_container_width=True)
303
+ else:
304
+ st.info("Please login with your HuggingFace account to use the app")
305
 
306
  if __name__ == "__main__":
307
  main()