satyanayak commited on
Commit
6269d98
·
1 Parent(s): 6df8da7

changed the concepts to 768-embeddings

Browse files
Files changed (1) hide show
  1. app.py +24 -15
app.py CHANGED
@@ -10,13 +10,13 @@ import os
10
  model_id = "CompVis/stable-diffusion-v1-4"
11
  device = "cuda" if torch.cuda.is_available() else "cpu"
12
 
13
- # List of concept embeddings to use
14
  concepts = [
15
- "sd-concepts-library/sword-lily-flowers102",
16
- "sd-concepts-library/azalea-flowers102",
17
- "sd-concepts-library/samurai-jack",
18
- "sd-concepts-library/wu-shi-art",
19
- "sd-concepts-library/wu-shi"
20
  ]
21
 
22
  def download_concept_embedding(concept_name):
@@ -50,6 +50,7 @@ def load_learned_embed_in_clip(learned_embeds_path, text_encoder, tokenizer):
50
 
51
  def generate_images(prompt):
52
  images = []
 
53
 
54
  # Load base pipeline
55
  pipe = StableDiffusionPipeline.from_pretrained(
@@ -58,12 +59,13 @@ def generate_images(prompt):
58
  ).to(device)
59
 
60
  for concept in concepts:
61
- # Download and load concept embedding
62
- embed_path = download_concept_embedding(concept)
63
- if embed_path is None:
64
- continue
65
-
66
  try:
 
 
 
 
 
 
67
  token = load_learned_embed_in_clip(
68
  embed_path,
69
  pipe.text_encoder,
@@ -81,7 +83,7 @@ def generate_images(prompt):
81
  with autocast(device):
82
  image = pipe(
83
  concept_prompt,
84
- num_inference_steps=50,
85
  generator=generator,
86
  guidance_scale=7.5
87
  ).images[0]
@@ -94,17 +96,24 @@ def generate_images(prompt):
94
 
95
  except Exception as e:
96
  print(f"Error processing concept {concept}: {str(e)}")
 
97
  continue
98
 
99
- return images if images else [None] * 5
 
 
 
 
 
 
100
 
101
  # Create Gradio interface
102
  iface = gr.Interface(
103
  fn=generate_images,
104
  inputs=gr.Textbox(label="Enter your prompt"),
105
- outputs=[gr.Image() for _ in range(5)],
106
  title="Multi-Concept Stable Diffusion Generator",
107
- description="Generate images using 5 different concepts from the SD Concepts Library"
108
  )
109
 
110
  # Launch the app
 
10
  model_id = "CompVis/stable-diffusion-v1-4"
11
  device = "cuda" if torch.cuda.is_available() else "cpu"
12
 
13
+ # List of concept embeddings compatible with SD v1.4
14
  concepts = [
15
+ "sd-concepts-library/cat-toy",
16
+ "sd-concepts-library/disco-diffusion-style",
17
+ "sd-concepts-library/modern-disney-style",
18
+ "sd-concepts-library/charliebo-artstyle",
19
+ "sd-concepts-library/redshift-render-style"
20
  ]
21
 
22
  def download_concept_embedding(concept_name):
 
50
 
51
  def generate_images(prompt):
52
  images = []
53
+ failed_concepts = []
54
 
55
  # Load base pipeline
56
  pipe = StableDiffusionPipeline.from_pretrained(
 
59
  ).to(device)
60
 
61
  for concept in concepts:
 
 
 
 
 
62
  try:
63
+ # Download and load concept embedding
64
+ embed_path = download_concept_embedding(concept)
65
+ if embed_path is None:
66
+ failed_concepts.append(concept)
67
+ continue
68
+
69
  token = load_learned_embed_in_clip(
70
  embed_path,
71
  pipe.text_encoder,
 
83
  with autocast(device):
84
  image = pipe(
85
  concept_prompt,
86
+ num_inference_steps=40,
87
  generator=generator,
88
  guidance_scale=7.5
89
  ).images[0]
 
96
 
97
  except Exception as e:
98
  print(f"Error processing concept {concept}: {str(e)}")
99
+ failed_concepts.append(concept)
100
  continue
101
 
102
+ if failed_concepts:
103
+ print(f"Failed to process concepts: {', '.join(failed_concepts)}")
104
+
105
+ # Return available images, pad with None if some failed
106
+ while len(images) < 5:
107
+ images.append(None)
108
+ return images[:5]
109
 
110
  # Create Gradio interface
111
  iface = gr.Interface(
112
  fn=generate_images,
113
  inputs=gr.Textbox(label="Enter your prompt"),
114
+ outputs=[gr.Image(label=f"Concept {i+1}") for i in range(5)],
115
  title="Multi-Concept Stable Diffusion Generator",
116
+ description="Generate images using 5 different artistic concepts from the SD Concepts Library"
117
  )
118
 
119
  # Launch the app