patrickbdevaney committed on
Commit
4e91b3f
·
verified ·
1 Parent(s): 5f554bd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +168 -167
app.py CHANGED
@@ -1,172 +1,172 @@
1
- import multiprocessing
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
  if __name__ == '__main__':
4
- multiprocessing.set_start_method('spawn')
5
-
6
- import spaces # Import spaces at the very beginning
7
- import os
8
- import pandas as pd
9
- import torch
10
- import gc
11
- import re
12
- import random
13
- from tqdm.auto import tqdm
14
- from collections import deque
15
- from optimum.quanto import freeze, qfloat8, quantize
16
- from diffusers import FlowMatchEulerDiscreteScheduler, AutoencoderKL
17
- from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
18
- from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
19
- from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
20
- from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
21
- import gradio as gr
22
- from accelerate import Accelerator
23
-
24
- # Instantiate the Accelerator
25
- accelerator = Accelerator()
26
-
27
- dtype = torch.bfloat16
28
-
29
- # Set environment variables for local path
30
- os.environ['FLUX_DEV'] = '.'
31
- os.environ['AE'] = '.'
32
-
33
- bfl_repo = 'black-forest-labs/FLUX.1-schnell'
34
- revision = 'refs/pr/1'
35
-
36
- scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(bfl_repo, subfolder='scheduler', revision=revision)
37
- text_encoder = CLIPTextModel.from_pretrained('openai/clip-vit-large-patch14', torch_dtype=dtype)
38
- tokenizer = CLIPTokenizer.from_pretrained('openai/clip-vit-large-patch14', torch_dtype=dtype)
39
- text_encoder_2 = T5EncoderModel.from_pretrained(bfl_repo, subfolder='text_encoder_2', torch_dtype=dtype, revision=revision)
40
- tokenizer_2 = T5TokenizerFast.from_pretrained(bfl_repo, subfolder='tokenizer_2', torch_dtype=dtype, revision=revision)
41
- vae = AutoencoderKL.from_pretrained(bfl_repo, subfolder='vae', torch_dtype=dtype, revision=revision)
42
- transformer = FluxTransformer2DModel.from_pretrained(bfl_repo, subfolder='transformer', torch_dtype=dtype, revision=revision)
43
-
44
- quantize(transformer, weights=qfloat8)
45
- freeze(transformer)
46
- quantize(text_encoder_2, weights=qfloat8)
47
- freeze(text_encoder_2)
48
-
49
- pipe = FluxPipeline(
50
- scheduler=scheduler,
51
- text_encoder=text_encoder,
52
- tokenizer=tokenizer,
53
- text_encoder_2=None,
54
- tokenizer_2=tokenizer_2,
55
- vae=vae,
56
- transformer=None,
57
- )
58
- pipe.text_encoder_2 = text_encoder_2
59
- pipe.transformer = transformer
60
- pipe.enable_model_cpu_offload()
61
-
62
- # Create a directory to save the generated images
63
- output_dir = 'generated_images'
64
- os.makedirs(output_dir, exist_ok=True)
65
-
66
- # Function to generate a detailed visual description prompt
67
- def generate_description_prompt(subject, user_prompt, text_generator):
68
- prompt = f"write concise vivid visual description enclosed in brackets like [ <description> ] less than 100 words of {user_prompt} different from {subject}. "
69
- try:
70
- generated_text = text_generator(prompt, max_length=160, num_return_sequences=1, truncation=True)[0]['generated_text']
71
- generated_description = re.sub(rf'{re.escape(prompt)}\s*', '', generated_text).strip() # Remove the prompt from the generated text
72
- return generated_description if generated_description else None
73
- except Exception as e:
74
- print(f"Error generating description for subject '{subject}': {e}")
75
- return None
76
-
77
- # Function to parse descriptions from a given text
78
- def parse_descriptions(text):
79
- # Find all descriptions enclosed in brackets
80
- descriptions = re.findall(r'\[([^\[\]]+)\]', text)
81
- # Filter descriptions with at least 3 words
82
- descriptions = [desc.strip() for desc in descriptions if len(desc.split()) >= 3]
83
- return descriptions
84
-
85
- # Seed words pool
86
- seed_words = []
87
-
88
- used_words = set()
89
- paused = False
90
-
91
- # Queue to store parsed descriptions
92
- parsed_descriptions_queue = deque()
93
-
94
- # Usage limits
95
- MAX_DESCRIPTIONS = 10
96
- MAX_IMAGES = 5
97
-
98
- @spaces.GPU
99
- def generate_descriptions(user_prompt, seed_words_input, batch_size=100, max_iterations=50):
100
- global paused
101
- descriptions = []
102
- description_queue = deque()
103
- iteration_count = 0
104
-
105
- # Initialize the text generation pipeline with 16-bit precision
106
- print("Initializing the text generation pipeline with 16-bit precision...")
107
- model_name = 'meta-llama/Meta-Llama-3.1-8B-Instruct'
108
- model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map='auto')
109
- tokenizer = AutoTokenizer.from_pretrained(model_name)
110
- text_generator = pipeline('text-generation', model=model, tokenizer=tokenizer)
111
- print("Text generation pipeline initialized with 16-bit precision.")
112
-
113
- # Populate the seed_words array with user input
114
- seed_words.extend(re.findall(r'"(.*?)"', seed_words_input))
115
-
116
- while iteration_count < max_iterations and len(parsed_descriptions_queue) < MAX_DESCRIPTIONS:
117
- if paused:
118
- break
119
-
120
- # Select a subject that has not been used
121
- available_subjects = [word for word in seed_words if word not in used_words]
122
- if not available_subjects:
123
- print("No more available subjects to use.")
124
- break
125
-
126
- subject = random.choice(available_subjects)
127
- generated_description = generate_description_prompt(subject, user_prompt, text_generator)
128
-
129
- if generated_description:
130
- # Remove any offending symbols
131
- clean_description = generated_description.encode('ascii', 'ignore').decode('ascii')
132
- description_queue.append({'subject': subject, 'description': clean_description})
133
-
134
- # Print the generated description to the command line
135
- print(f"Generated description for subject '{subject}': {clean_description}")
136
-
137
- # Update used words and seed words
138
- used_words.add(subject)
139
- seed_words.append(clean_description) # Add the generated description to the seed bank array
140
-
141
- # Parse and append descriptions every 3 iterations
142
- if iteration_count % 3 == 0:
143
- parsed_descriptions = parse_descriptions(clean_description)
144
- parsed_descriptions_queue.extend(parsed_descriptions)
145
-
146
- iteration_count += 1
147
-
148
- return list(parsed_descriptions_queue)
149
-
150
- @spaces.GPU(duration=120)
151
- def generate_images(parsed_descriptions):
152
- # If there are fewer than MAX_IMAGES descriptions, use whatever is available
153
- if len(parsed_descriptions) < MAX_IMAGES:
154
- prompts = parsed_descriptions
155
- else:
156
- prompts = [parsed_descriptions.pop(0) for _ in range(MAX_IMAGES)]
157
-
158
- # Generate images from the parsed descriptions
159
- images = []
160
- for prompt in prompts:
161
- images.extend(pipe(prompt, num_images=1).images)
162
-
163
- return images
164
-
165
- # Create Gradio Interface
166
- def combined_function(user_prompt, seed_words_input):
167
- parsed_descriptions = generate_descriptions(user_prompt, seed_words_input)
168
- images = generate_images(parsed_descriptions)
169
- return images
170
 
171
  interface = gr.Interface(
172
  fn=combined_function,
@@ -175,3 +175,4 @@ if __name__ == '__main__':
175
  )
176
 
177
  interface.launch()
 
 
1
import spaces  # Import spaces at the very beginning (ZeroGPU decorator; presumably must be imported before CUDA init — confirm)
import torch.multiprocessing as mp
import torch
import os
import pandas as pd
import gc
import re
import random
from tqdm.auto import tqdm
from collections import deque
from optimum.quanto import freeze, qfloat8, quantize
from diffusers import FlowMatchEulerDiscreteScheduler, AutoencoderKL
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
import gradio as gr
from accelerate import Accelerator
19
+
20
# Instantiate the Accelerator
accelerator = Accelerator()

# Working precision for all FLUX pipeline components.
dtype = torch.bfloat16

# Set environment variables for local path
os.environ['FLUX_DEV'] = '.'
os.environ['AE'] = '.'

# Model repository and pinned revision for FLUX.1-schnell.
bfl_repo = 'black-forest-labs/FLUX.1-schnell'
revision = 'refs/pr/1'

# Load every pipeline component individually so the two largest pieces
# (text_encoder_2 and transformer) can be quantized before assembly.
# NOTE(review): all of this runs at import time, before __main__ sets the
# multiprocessing start method at the bottom of the file — confirm that
# ordering is intended.
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(bfl_repo, subfolder='scheduler', revision=revision)
text_encoder = CLIPTextModel.from_pretrained('openai/clip-vit-large-patch14', torch_dtype=dtype)
tokenizer = CLIPTokenizer.from_pretrained('openai/clip-vit-large-patch14', torch_dtype=dtype)
text_encoder_2 = T5EncoderModel.from_pretrained(bfl_repo, subfolder='text_encoder_2', torch_dtype=dtype, revision=revision)
tokenizer_2 = T5TokenizerFast.from_pretrained(bfl_repo, subfolder='tokenizer_2', torch_dtype=dtype, revision=revision)
vae = AutoencoderKL.from_pretrained(bfl_repo, subfolder='vae', torch_dtype=dtype, revision=revision)
transformer = FluxTransformer2DModel.from_pretrained(bfl_repo, subfolder='transformer', torch_dtype=dtype, revision=revision)

# Quantize the transformer and T5 encoder to 8-bit floats, then freeze them
# so the quantized weights stay fixed.
quantize(transformer, weights=qfloat8)
freeze(transformer)
quantize(text_encoder_2, weights=qfloat8)
freeze(text_encoder_2)

# Assemble the pipeline with the two quantized components set to None, then
# attach them afterwards — presumably to keep the pipeline constructor from
# re-casting the quantized modules; confirm against diffusers behavior.
pipe = FluxPipeline(
    scheduler=scheduler,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    text_encoder_2=None,
    tokenizer_2=tokenizer_2,
    vae=vae,
    transformer=None,
)
pipe.text_encoder_2 = text_encoder_2
pipe.transformer = transformer
pipe.enable_model_cpu_offload()

# Create a directory to save the generated images
output_dir = 'generated_images'
os.makedirs(output_dir, exist_ok=True)
61
+
62
# Function to generate a detailed visual description prompt
def generate_description_prompt(subject, user_prompt, text_generator):
    """Produce a short bracketed visual description via *text_generator*.

    The instruction prompt echoed back by the model is stripped from the
    output. Returns None on any generation error, or when nothing remains
    after stripping.
    """
    prompt = f"write concise vivid visual description enclosed in brackets like [ <description> ] less than 100 words of {user_prompt} different from {subject}. "
    try:
        output = text_generator(prompt, max_length=160, num_return_sequences=1, truncation=True)
        # Drop the echoed instruction prompt plus any trailing whitespace.
        description = re.sub(rf'{re.escape(prompt)}\s*', '', output[0]['generated_text']).strip()
        return description if description else None
    except Exception as err:
        print(f"Error generating description for subject '{subject}': {err}")
        return None
72
+
73
# Function to parse descriptions from a given text
def parse_descriptions(text):
    """Return bracketed fragments of *text* that contain at least 3 words."""
    bracketed = re.findall(r'\[([^\[\]]+)\]', text)
    return [fragment.strip() for fragment in bracketed if len(fragment.split()) >= 3]
80
+
81
# Seed words pool
# Mutable module-level pool of subjects; generate_descriptions() both reads
# from it and appends generated descriptions back into it.
seed_words = []

# Subjects already consumed, so each seed word is only used once per process.
used_words = set()
# Cooperative stop flag checked inside the generation loop.
paused = False

# Queue to store parsed descriptions
parsed_descriptions_queue = deque()

# Usage limits
MAX_DESCRIPTIONS = 10
MAX_IMAGES = 5
93
+
94
@spaces.GPU
def generate_descriptions(user_prompt, seed_words_input, batch_size=100, max_iterations=50):
    """Generate bracketed visual descriptions seeded from user-supplied words.

    Args:
        user_prompt: Theme the descriptions should relate to.
        seed_words_input: String containing double-quoted seed subjects,
            e.g. '"cat" "dog"'.
        batch_size: Unused; kept for backward compatibility with callers.
        max_iterations: Upper bound on generation attempts.

    Returns:
        A list snapshot of the module-level parsed_descriptions_queue; the
        loop stops once it holds MAX_DESCRIPTIONS entries, max_iterations is
        reached, subjects run out, or `paused` is set.
    """
    global paused
    # Write-only log of raw generations; never read back (kept for debugging).
    description_queue = deque()
    iteration_count = 0

    # Initialize the text generation pipeline with 16-bit precision
    print("Initializing the text generation pipeline with 16-bit precision...")
    model_name = 'meta-llama/Meta-Llama-3.1-8B-Instruct'
    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map='auto')
    # NOTE(review): this local name shadows the module-level CLIP tokenizer.
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    text_generator = pipeline('text-generation', model=model, tokenizer=tokenizer)
    print("Text generation pipeline initialized with 16-bit precision.")

    # Populate the module-level seed_words pool with the quoted user input.
    # NOTE(review): seed_words is global, so repeated calls keep growing it.
    seed_words.extend(re.findall(r'"(.*?)"', seed_words_input))

    while iteration_count < max_iterations and len(parsed_descriptions_queue) < MAX_DESCRIPTIONS:
        if paused:
            break

        # Select a subject that has not been used
        available_subjects = [word for word in seed_words if word not in used_words]
        if not available_subjects:
            print("No more available subjects to use.")
            break

        subject = random.choice(available_subjects)
        generated_description = generate_description_prompt(subject, user_prompt, text_generator)

        if generated_description:
            # Remove any offending symbols (keep ASCII only)
            clean_description = generated_description.encode('ascii', 'ignore').decode('ascii')
            description_queue.append({'subject': subject, 'description': clean_description})

            # Print the generated description to the command line
            print(f"Generated description for subject '{subject}': {clean_description}")

            # Update used words and seed words
            used_words.add(subject)
            seed_words.append(clean_description)  # Add the generated description to the seed bank array

            # Parse and append descriptions every 3 iterations
            if iteration_count % 3 == 0:
                parsed_descriptions_queue.extend(parse_descriptions(clean_description))

        iteration_count += 1

    return list(parsed_descriptions_queue)
145
+
146
@spaces.GPU(duration=120)
def generate_images(parsed_descriptions):
    """Render up to MAX_IMAGES prompts with the module-level FLUX pipeline.

    Args:
        parsed_descriptions: List of prompt strings. When it holds MAX_IMAGES
            or more entries, the first MAX_IMAGES are popped from the
            caller's list (the list is mutated).

    Returns:
        List of generated images (one per prompt).
    """
    # If there are fewer than MAX_IMAGES descriptions, use whatever is available
    if len(parsed_descriptions) < MAX_IMAGES:
        prompts = parsed_descriptions
    else:
        prompts = [parsed_descriptions.pop(0) for _ in range(MAX_IMAGES)]

    # Generate images from the parsed descriptions.
    # BUG FIX: FluxPipeline.__call__ takes `num_images_per_prompt`; the
    # previous keyword `num_images` is not accepted and raised a TypeError.
    images = []
    for prompt in prompts:
        images.extend(pipe(prompt, num_images_per_prompt=1).images)

    return images
160
+
161
# Create Gradio Interface
def combined_function(user_prompt, seed_words_input):
    """Gradio entry point: generate descriptions, then render them as images."""
    descriptions = generate_descriptions(user_prompt, seed_words_input)
    return generate_images(descriptions)
166
 
167
if __name__ == '__main__':
    # 'spawn' is the start method compatible with CUDA in child processes;
    # force=True keeps this from raising RuntimeError when a start method
    # has already been set in this interpreter.
    mp.set_start_method('spawn', force=True)
    # BUG FIX: removed the call to initialize_cuda() — that function is not
    # defined anywhere in this file and caused a NameError at startup.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
170
 
171
  interface = gr.Interface(
172
  fn=combined_function,
 
175
  )
176
 
177
  interface.launch()
178
+