patrickbdevaney committed on
Commit 5f554bd · verified · 1 Parent(s): 3aaa901

Update app.py

Files changed (1)
  1. app.py +177 -168
app.py CHANGED
@@ -1,168 +1,177 @@
- import spaces  # Import spaces at the very beginning
- import os
- import pandas as pd
- import torch
- import gc
- import re
- import random
- from tqdm.auto import tqdm
- from collections import deque
- from optimum.quanto import freeze, qfloat8, quantize
- from diffusers import FlowMatchEulerDiscreteScheduler, AutoencoderKL
- from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
- from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
- from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
- from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
- import gradio as gr
- from accelerate import Accelerator
-
- # Instantiate the Accelerator
- accelerator = Accelerator()
-
- dtype = torch.bfloat16
-
- # Set environment variables for local path
- os.environ['FLUX_DEV'] = '.'
- os.environ['AE'] = '.'
-
- bfl_repo = 'black-forest-labs/FLUX.1-schnell'
- revision = 'refs/pr/1'
-
- scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(bfl_repo, subfolder='scheduler', revision=revision)
- text_encoder = CLIPTextModel.from_pretrained('openai/clip-vit-large-patch14', torch_dtype=dtype)
- tokenizer = CLIPTokenizer.from_pretrained('openai/clip-vit-large-patch14', torch_dtype=dtype)
- text_encoder_2 = T5EncoderModel.from_pretrained(bfl_repo, subfolder='text_encoder_2', torch_dtype=dtype, revision=revision)
- tokenizer_2 = T5TokenizerFast.from_pretrained(bfl_repo, subfolder='tokenizer_2', torch_dtype=dtype, revision=revision)
- vae = AutoencoderKL.from_pretrained(bfl_repo, subfolder='vae', torch_dtype=dtype, revision=revision)
- transformer = FluxTransformer2DModel.from_pretrained(bfl_repo, subfolder='transformer', torch_dtype=dtype, revision=revision)
-
- quantize(transformer, weights=qfloat8)
- freeze(transformer)
- quantize(text_encoder_2, weights=qfloat8)
- freeze(text_encoder_2)
-
- pipe = FluxPipeline(
-     scheduler=scheduler,
-     text_encoder=text_encoder,
-     tokenizer=tokenizer,
-     text_encoder_2=None,
-     tokenizer_2=tokenizer_2,
-     vae=vae,
-     transformer=None,
- )
- pipe.text_encoder_2 = text_encoder_2
- pipe.transformer = transformer
- pipe.enable_model_cpu_offload()
-
- # Create a directory to save the generated images
- output_dir = 'generated_images'
- os.makedirs(output_dir, exist_ok=True)
-
- # Function to generate a detailed visual description prompt
- def generate_description_prompt(subject, user_prompt, text_generator):
-     prompt = f"write concise vivid visual description enclosed in brackets like [ <description> ] less than 100 words of {user_prompt} different from {subject}. "
-     try:
-         generated_text = text_generator(prompt, max_length=160, num_return_sequences=1, truncation=True)[0]['generated_text']
-         generated_description = re.sub(rf'{re.escape(prompt)}\s*', '', generated_text).strip()  # Remove the prompt from the generated text
-         return generated_description if generated_description else None
-     except Exception as e:
-         print(f"Error generating description for subject '{subject}': {e}")
-         return None
-
- # Function to parse descriptions from a given text
- def parse_descriptions(text):
-     # Find all descriptions enclosed in brackets
-     descriptions = re.findall(r'\[([^\[\]]+)\]', text)
-     # Filter descriptions with at least 3 words
-     descriptions = [desc.strip() for desc in descriptions if len(desc.split()) >= 3]
-     return descriptions
-
- # Seed words pool
- seed_words = []
-
- used_words = set()
- paused = False
-
- # Queue to store parsed descriptions
- parsed_descriptions_queue = deque()
-
- @spaces.GPU
- def generate_descriptions(user_prompt, seed_words_input, batch_size=100, max_iterations=50):
-     global paused
-     descriptions = []
-     description_queue = deque()
-     iteration_count = 0
-
-     # Initialize the text generation pipeline with 16-bit precision
-     print("Initializing the text generation pipeline with 16-bit precision...")
-     model_name = 'meta-llama/Meta-Llama-3.1-8B-Instruct'
-     model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map='auto')
-     tokenizer = AutoTokenizer.from_pretrained(model_name)
-     text_generator = pipeline('text-generation', model=model, tokenizer=tokenizer)
-     print("Text generation pipeline initialized with 16-bit precision.")
-
-     # Populate the seed_words array with user input
-     seed_words.extend(re.findall(r'"(.*?)"', seed_words_input))
-
-     while iteration_count < max_iterations:
-         if paused:
-             break
-
-         # Select a subject that has not been used
-         available_subjects = [word for word in seed_words if word not in used_words]
-         if not available_subjects:
-             print("No more available subjects to use.")
-             break
-
-         subject = random.choice(available_subjects)
-         generated_description = generate_description_prompt(subject, user_prompt, text_generator)
-
-         if generated_description:
-             # Remove any offending symbols
-             clean_description = generated_description.encode('ascii', 'ignore').decode('ascii')
-             description_queue.append({'subject': subject, 'description': clean_description})
-
-             # Print the generated description to the command line
-             print(f"Generated description for subject '{subject}': {clean_description}")
-
-             # Update used words and seed words
-             used_words.add(subject)
-             seed_words.append(clean_description)  # Add the generated description to the seed bank array
-
-             # Parse and append descriptions every 3 iterations
-             if iteration_count % 3 == 0:
-                 parsed_descriptions = parse_descriptions(clean_description)
-                 parsed_descriptions_queue.extend(parsed_descriptions)
-
-         iteration_count += 1
-
-     return list(parsed_descriptions_queue)
-
- @spaces.GPU(duration=120)
- def generate_images(parsed_descriptions):
-     # If there are fewer than 13 descriptions, use whatever is available
-     if len(parsed_descriptions) < 13:
-         prompts = parsed_descriptions
-     else:
-         prompts = [parsed_descriptions.pop(0) for _ in range(13)]
-
-     # Generate images from the parsed descriptions
-     images = []
-     for prompt in prompts:
-         images.extend(pipe(prompt, num_images=1).images)
-
-     return images
-
- # Create Gradio Interface
- def combined_function(user_prompt, seed_words_input):
-     parsed_descriptions = generate_descriptions(user_prompt, seed_words_input)
-     images = generate_images(parsed_descriptions)
-     return images
-
- interface = gr.Interface(
-     fn=combined_function,
-     inputs=[gr.Textbox(lines=2, placeholder="Enter a prompt for descriptions..."), gr.Textbox(lines=2, placeholder='Enter seed words in quotes, e.g., "cat", "dog", "sunset"...')],
-     outputs=gr.Gallery()
- )
-
- interface.launch()
+ import multiprocessing
+
+ if __name__ == '__main__':
+     multiprocessing.set_start_method('spawn')
+
+ import spaces  # Import spaces at the very beginning
+ import os
+ import pandas as pd
+ import torch
+ import gc
+ import re
+ import random
+ from tqdm.auto import tqdm
+ from collections import deque
+ from optimum.quanto import freeze, qfloat8, quantize
+ from diffusers import FlowMatchEulerDiscreteScheduler, AutoencoderKL
+ from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
+ from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
+ from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
+ from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
+ import gradio as gr
+ from accelerate import Accelerator
+
+ # Instantiate the Accelerator
+ accelerator = Accelerator()
+
+ dtype = torch.bfloat16
+
+ # Set environment variables for local path
+ os.environ['FLUX_DEV'] = '.'
+ os.environ['AE'] = '.'
+
+ bfl_repo = 'black-forest-labs/FLUX.1-schnell'
+ revision = 'refs/pr/1'
+
+ scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(bfl_repo, subfolder='scheduler', revision=revision)
+ text_encoder = CLIPTextModel.from_pretrained('openai/clip-vit-large-patch14', torch_dtype=dtype)
+ tokenizer = CLIPTokenizer.from_pretrained('openai/clip-vit-large-patch14', torch_dtype=dtype)
+ text_encoder_2 = T5EncoderModel.from_pretrained(bfl_repo, subfolder='text_encoder_2', torch_dtype=dtype, revision=revision)
+ tokenizer_2 = T5TokenizerFast.from_pretrained(bfl_repo, subfolder='tokenizer_2', torch_dtype=dtype, revision=revision)
+ vae = AutoencoderKL.from_pretrained(bfl_repo, subfolder='vae', torch_dtype=dtype, revision=revision)
+ transformer = FluxTransformer2DModel.from_pretrained(bfl_repo, subfolder='transformer', torch_dtype=dtype, revision=revision)
+
+ quantize(transformer, weights=qfloat8)
+ freeze(transformer)
+ quantize(text_encoder_2, weights=qfloat8)
+ freeze(text_encoder_2)
+
+ pipe = FluxPipeline(
+     scheduler=scheduler,
+     text_encoder=text_encoder,
+     tokenizer=tokenizer,
+     text_encoder_2=None,
+     tokenizer_2=tokenizer_2,
+     vae=vae,
+     transformer=None,
+ )
+ pipe.text_encoder_2 = text_encoder_2
+ pipe.transformer = transformer
+ pipe.enable_model_cpu_offload()
+
+ # Create a directory to save the generated images
+ output_dir = 'generated_images'
+ os.makedirs(output_dir, exist_ok=True)
+
+ # Function to generate a detailed visual description prompt
+ def generate_description_prompt(subject, user_prompt, text_generator):
+     prompt = f"write concise vivid visual description enclosed in brackets like [ <description> ] less than 100 words of {user_prompt} different from {subject}. "
+     try:
+         generated_text = text_generator(prompt, max_length=160, num_return_sequences=1, truncation=True)[0]['generated_text']
+         generated_description = re.sub(rf'{re.escape(prompt)}\s*', '', generated_text).strip()  # Remove the prompt from the generated text
+         return generated_description if generated_description else None
+     except Exception as e:
+         print(f"Error generating description for subject '{subject}': {e}")
+         return None
+
+ # Function to parse descriptions from a given text
+ def parse_descriptions(text):
+     # Find all descriptions enclosed in brackets
+     descriptions = re.findall(r'\[([^\[\]]+)\]', text)
+     # Filter descriptions with at least 3 words
+     descriptions = [desc.strip() for desc in descriptions if len(desc.split()) >= 3]
+     return descriptions
+
+ # Seed words pool
+ seed_words = []
+
+ used_words = set()
+ paused = False
+
+ # Queue to store parsed descriptions
+ parsed_descriptions_queue = deque()
+
+ # Usage limits
+ MAX_DESCRIPTIONS = 10
+ MAX_IMAGES = 5
+
+ @spaces.GPU
+ def generate_descriptions(user_prompt, seed_words_input, batch_size=100, max_iterations=50):
+     global paused
+     descriptions = []
+     description_queue = deque()
+     iteration_count = 0
+
+     # Initialize the text generation pipeline with 16-bit precision
+     print("Initializing the text generation pipeline with 16-bit precision...")
+     model_name = 'meta-llama/Meta-Llama-3.1-8B-Instruct'
+     model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map='auto')
+     tokenizer = AutoTokenizer.from_pretrained(model_name)
+     text_generator = pipeline('text-generation', model=model, tokenizer=tokenizer)
+     print("Text generation pipeline initialized with 16-bit precision.")
+
+     # Populate the seed_words array with user input
+     seed_words.extend(re.findall(r'"(.*?)"', seed_words_input))
+
+     while iteration_count < max_iterations and len(parsed_descriptions_queue) < MAX_DESCRIPTIONS:
+         if paused:
+             break
+
+         # Select a subject that has not been used
+         available_subjects = [word for word in seed_words if word not in used_words]
+         if not available_subjects:
+             print("No more available subjects to use.")
+             break
+
+         subject = random.choice(available_subjects)
+         generated_description = generate_description_prompt(subject, user_prompt, text_generator)
+
+         if generated_description:
+             # Remove any offending symbols
+             clean_description = generated_description.encode('ascii', 'ignore').decode('ascii')
+             description_queue.append({'subject': subject, 'description': clean_description})
+
+             # Print the generated description to the command line
+             print(f"Generated description for subject '{subject}': {clean_description}")
+
+             # Update used words and seed words
+             used_words.add(subject)
+             seed_words.append(clean_description)  # Add the generated description to the seed bank array
+
+             # Parse and append descriptions every 3 iterations
+             if iteration_count % 3 == 0:
+                 parsed_descriptions = parse_descriptions(clean_description)
+                 parsed_descriptions_queue.extend(parsed_descriptions)
+
+         iteration_count += 1
+
+     return list(parsed_descriptions_queue)
+
+ @spaces.GPU(duration=120)
+ def generate_images(parsed_descriptions):
+     # If there are fewer than MAX_IMAGES descriptions, use whatever is available
+     if len(parsed_descriptions) < MAX_IMAGES:
+         prompts = parsed_descriptions
+     else:
+         prompts = [parsed_descriptions.pop(0) for _ in range(MAX_IMAGES)]
+
+     # Generate images from the parsed descriptions
+     images = []
+     for prompt in prompts:
+         images.extend(pipe(prompt, num_images_per_prompt=1).images)  # FluxPipeline expects num_images_per_prompt; a bare num_images kwarg raises a TypeError
+
+     return images
+
+ # Create Gradio Interface
+ def combined_function(user_prompt, seed_words_input):
+     parsed_descriptions = generate_descriptions(user_prompt, seed_words_input)
+     images = generate_images(parsed_descriptions)
+     return images
+
+ interface = gr.Interface(
+     fn=combined_function,
+     inputs=[gr.Textbox(lines=2, placeholder="Enter a prompt for descriptions..."), gr.Textbox(lines=2, placeholder='Enter seed words in quotes, e.g., "cat", "dog", "sunset"...')],
+     outputs=gr.Gallery()
+ )
+
+ interface.launch()
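
The main functional change in this commit is the multiprocessing guard added at the top of app.py: CUDA cannot be re-initialized in a child created with the default fork start method on Linux, so 'spawn' must be selected before any worker processes exist. A minimal sketch of the pattern (the worker function and pool below are illustrative, not part of the commit):

import multiprocessing

def worker(idx):
    # Each spawned child starts a fresh interpreter, so it can safely
    # build its own CUDA context instead of inheriting a forked,
    # unusable one from the parent.
    return idx * idx

if __name__ == '__main__':
    # Must be called once, before any Pool or Process is created.
    multiprocessing.set_start_method('spawn')
    with multiprocessing.Pool(processes=2) as pool:
        print(pool.map(worker, range(4)))  # [0, 1, 4, 9]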
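
The bracket-parsing helper that feeds generate_images is easy to sanity-check in isolation; a quick standalone example of its behavior (the sample string is made up):

import re

def parse_descriptions(text):
    # Same logic as app.py: grab bracketed spans, keep those with 3+ words.
    descriptions = re.findall(r'\[([^\[\]]+)\]', text)
    return [desc.strip() for desc in descriptions if len(desc.split()) >= 3]

print(parse_descriptions("intro [a misty pine forest at dawn] noise [ok] end"))
# -> ['a misty pine forest at dawn']  ('ok' is dropped: fewer than 3 words)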