patrickbdevaney commited on
Commit
8033423
·
verified ·
1 Parent(s): 519d719

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -46
app.py CHANGED
@@ -83,68 +83,41 @@ def initialize_diffusers():
83
  pipe = initialize_diffusers()
84
  print("Models and checkpoints preloaded.")
85
 
86
- def generate_description_prompt(subject, user_prompt, text_generator):
87
- prompt = f"write concise vivid visual description enclosed in brackets like [ <description> ] less than 100 words of {user_prompt} different from {subject}. "
 
 
88
  try:
89
- generated_text = text_generator(prompt, max_length=160, num_return_sequences=1, truncation=True)[0]['generated_text']
90
- generated_description = re.sub(rf'{re.escape(prompt)}\s*', '', generated_text).strip() # Remove the prompt from the generated text
91
- return generated_description if generated_description else None
92
  except Exception as e:
93
- print(f"Error generating description for subject '{subject}': {e}")
94
  return None
95
 
96
- def parse_descriptions(text):
97
- descriptions = re.findall(r'\[([^\[\]]+)\]', text)
98
- descriptions = [desc.strip() for desc in descriptions if len(desc.split()) >= 3]
99
- return descriptions
100
-
101
  def format_descriptions(descriptions):
102
  formatted_descriptions = "\n".join(descriptions)
103
  return formatted_descriptions
104
 
105
  @spaces.GPU
106
  def generate_descriptions(user_prompt, seed_words_input, batch_size=100, max_iterations=1): # Set max_iterations to 1
107
- descriptions = []
108
- description_queue = deque()
109
- iteration_count = 0
110
-
111
- seed_words.extend(re.findall(r'"(.*?)"', seed_words_input))
112
-
113
- for _ in range(2): # Perform two iterations
114
- while iteration_count < max_iterations and len(parsed_descriptions_queue) < MAX_DESCRIPTIONS:
115
- available_subjects = [word for word in seed_words if word not in used_words]
116
- if not available_subjects:
117
- print("No more available subjects to use.")
118
- break
119
-
120
- subject = random.choice(available_subjects)
121
- generated_description = generate_description_prompt(subject, user_prompt, text_generator)
122
-
123
- if generated_description:
124
- clean_description = generated_description.encode('ascii', 'ignore').decode('ascii')
125
- description_queue.append({'subject': subject, 'description': clean_description})
126
-
127
- print(f"Generated description for subject '{subject}': {clean_description}")
128
-
129
- used_words.add(subject)
130
- seed_words.append(clean_description)
131
-
132
- parsed_descriptions = parse_descriptions(clean_description)
133
- parsed_descriptions_queue.extend(parsed_descriptions)
134
-
135
- iteration_count += 1
136
-
137
- return list(parsed_descriptions_queue)
138
 
139
  @spaces.GPU(duration=120)
140
  def generate_images(parsed_descriptions, max_iterations=3): # Set max_iterations to 3
141
- # Limit the number of descriptions passed to the image generator to 2
142
  if len(parsed_descriptions) > MAX_IMAGES:
143
  parsed_descriptions = parsed_descriptions[:MAX_IMAGES]
144
 
145
  images = []
146
  for prompt in parsed_descriptions:
147
- images.extend(pipe(prompt, num_inference_steps=max_iterations, height=512, width=512).images) # Set resolution to 512 x 512
 
 
 
148
 
149
  return images
150
 
@@ -161,10 +134,10 @@ if __name__ == '__main__':
161
 
162
  interface = gr.Interface(
163
  fn=generate_and_display,
164
- inputs=[gr.Textbox(lines=2, placeholder="Enter a prompt for descriptions..."), gr.Textbox(lines=2, placeholder='Enter seed words in quotes, e.g., "cat", "dog", "sunset"...')],
165
  outputs=[gr.Textbox(label="Generated Descriptions"), gr.Gallery(label="Generated Images")],
166
  live=False, # Set live to False
167
  allow_flagging='never' # Disable flagging
168
  )
169
 
170
- interface.launch(share=True)
 
83
  pipe = initialize_diffusers()
84
  print("Models and checkpoints preloaded.")
85
 
86
+ def generate_description_prompt(user_prompt, text_generator):
87
+ injected_prompt = f"write three concise descriptions enclosed in brackets like [ <description> ] less than 100 words each of {user_prompt}. "
88
+ max_length = 110 # Set max token length to 110
89
+
90
  try:
91
+ generated_text = text_generator(injected_prompt, max_length=max_length, num_return_sequences=1, truncation=True)[0]['generated_text']
92
+ generated_descriptions = re.findall(r'\[([^\[\]]+)\]', generated_text) # Extract descriptions enclosed in brackets
93
+ return generated_descriptions if generated_descriptions else None
94
  except Exception as e:
95
+ print(f"Error generating descriptions: {e}")
96
  return None
97
 
 
 
 
 
 
98
  def format_descriptions(descriptions):
99
  formatted_descriptions = "\n".join(descriptions)
100
  return formatted_descriptions
101
 
102
  @spaces.GPU
103
  def generate_descriptions(user_prompt, seed_words_input, batch_size=100, max_iterations=1): # Set max_iterations to 1
104
+ descriptions = generate_description_prompt(user_prompt, text_generator)
105
+ if descriptions:
106
+ parsed_descriptions_queue.extend(descriptions)
107
+ return list(parsed_descriptions_queue)[:MAX_IMAGES]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
 
109
  @spaces.GPU(duration=120)
110
  def generate_images(parsed_descriptions, max_iterations=3): # Set max_iterations to 3
111
+ # Limit the number of descriptions passed to the image generator to MAX_IMAGES (3)
112
  if len(parsed_descriptions) > MAX_IMAGES:
113
  parsed_descriptions = parsed_descriptions[:MAX_IMAGES]
114
 
115
  images = []
116
  for prompt in parsed_descriptions:
117
+ try:
118
+ images.extend(pipe(prompt, num_inference_steps=4, height=1024, width=1024).images) # Set resolution to 512 x 512
119
+ except Exception as e:
120
+ print(f"Error generating image for prompt '{prompt}': {e}")
121
 
122
  return images
123
 
 
134
 
135
  interface = gr.Interface(
136
  fn=generate_and_display,
137
+ inputs=[gr.Textbox(lines=2, placeholder="Enter a prompt for descriptions..."), gr.Textbox(lines=2, placeholder='Enter example in quotes, e.g., "cat", "dog", "sunset"...')],
138
  outputs=[gr.Textbox(label="Generated Descriptions"), gr.Gallery(label="Generated Images")],
139
  live=False, # Set live to False
140
  allow_flagging='never' # Disable flagging
141
  )
142
 
143
+ interface.launch(share=True)