patrickbdevaney committed on
Commit
3aaa901
·
verified ·
1 Parent(s): c23176c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -19
app.py CHANGED
@@ -1,4 +1,4 @@
1
- import spaces
2
  import os
3
  import pandas as pd
4
  import torch
@@ -15,8 +15,6 @@ from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokeniz
15
  from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
16
  import gradio as gr
17
  from accelerate import Accelerator
18
- import huggingface_hub # Ensure this import is correct
19
-
20
 
21
  # Instantiate the Accelerator
22
  accelerator = Accelerator()
@@ -89,7 +87,7 @@ paused = False
89
  parsed_descriptions_queue = deque()
90
 
91
  @spaces.GPU
92
- def generate_and_store_descriptions(user_prompt, batch_size=100, max_iterations=50):
93
  global paused
94
  descriptions = []
95
  description_queue = deque()
@@ -103,6 +101,9 @@ def generate_and_store_descriptions(user_prompt, batch_size=100, max_iterations=
103
  text_generator = pipeline('text-generation', model=model, tokenizer=tokenizer)
104
  print("Text generation pipeline initialized with 16-bit precision.")
105
 
 
 
 
106
  while iteration_count < max_iterations:
107
  if paused:
108
  break
@@ -132,20 +133,18 @@ def generate_and_store_descriptions(user_prompt, batch_size=100, max_iterations=
132
  if iteration_count % 3 == 0:
133
  parsed_descriptions = parse_descriptions(clean_description)
134
  parsed_descriptions_queue.extend(parsed_descriptions)
135
- # Return the parsed descriptions to update the Gradio UI
136
- return list(parsed_descriptions_queue)
137
 
138
  iteration_count += 1
139
 
140
  return list(parsed_descriptions_queue)
141
 
142
  @spaces.GPU(duration=120)
143
- def generate_images_from_parsed_descriptions():
144
  # If there are fewer than 13 descriptions, use whatever is available
145
- if len(parsed_descriptions_queue) < 13:
146
- prompts = list(parsed_descriptions_queue)
147
  else:
148
- prompts = [parsed_descriptions_queue.popleft() for _ in range(13)]
149
 
150
  # Generate images from the parsed descriptions
151
  images = []
@@ -155,16 +154,15 @@ def generate_images_from_parsed_descriptions():
155
  return images
156
 
157
  # Create Gradio Interface
158
- description_interface = gr.Interface(
159
- fn=generate_and_store_descriptions,
160
- inputs=gr.Textbox(lines=2, placeholder="Enter a prompt for descriptions..."),
161
- outputs="json"
162
- )
163
 
164
- image_interface = gr.Interface(
165
- fn=generate_images_from_parsed_descriptions,
166
- inputs=None,
167
  outputs=gr.Gallery()
168
  )
169
 
170
- gr.TabbedInterface([description_interface, image_interface], ["Generate Descriptions", "Generate Images"]).launch()
 
1
+ import spaces # Import spaces at the very beginning
2
  import os
3
  import pandas as pd
4
  import torch
 
15
  from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
16
  import gradio as gr
17
  from accelerate import Accelerator
 
 
18
 
19
  # Instantiate the Accelerator
20
  accelerator = Accelerator()
 
87
  parsed_descriptions_queue = deque()
88
 
89
  @spaces.GPU
90
+ def generate_descriptions(user_prompt, seed_words_input, batch_size=100, max_iterations=50):
91
  global paused
92
  descriptions = []
93
  description_queue = deque()
 
101
  text_generator = pipeline('text-generation', model=model, tokenizer=tokenizer)
102
  print("Text generation pipeline initialized with 16-bit precision.")
103
 
104
+ # Populate the seed_words array with user input
105
+ seed_words.extend(re.findall(r'"(.*?)"', seed_words_input))
106
+
107
  while iteration_count < max_iterations:
108
  if paused:
109
  break
 
133
  if iteration_count % 3 == 0:
134
  parsed_descriptions = parse_descriptions(clean_description)
135
  parsed_descriptions_queue.extend(parsed_descriptions)
 
 
136
 
137
  iteration_count += 1
138
 
139
  return list(parsed_descriptions_queue)
140
 
141
  @spaces.GPU(duration=120)
142
+ def generate_images(parsed_descriptions):
143
  # If there are fewer than 13 descriptions, use whatever is available
144
+ if len(parsed_descriptions) < 13:
145
+ prompts = parsed_descriptions
146
  else:
147
+ prompts = [parsed_descriptions.pop(0) for _ in range(13)]
148
 
149
  # Generate images from the parsed descriptions
150
  images = []
 
154
  return images
155
 
156
  # Create Gradio Interface
157
+ def combined_function(user_prompt, seed_words_input):
158
+ parsed_descriptions = generate_descriptions(user_prompt, seed_words_input)
159
+ images = generate_images(parsed_descriptions)
160
+ return images
 
161
 
162
+ interface = gr.Interface(
163
+ fn=combined_function,
164
+ inputs=[gr.Textbox(lines=2, placeholder="Enter a prompt for descriptions..."), gr.Textbox(lines=2, placeholder='Enter seed words in quotes, e.g., "cat", "dog", "sunset"...')],
165
  outputs=gr.Gallery()
166
  )
167
 
168
+ interface.launch()