patrickbdevaney committed on
Commit f8f7d35 · verified · 1 Parent(s): acc9295

Update app.py

Files changed (1):
  1. app.py +158 -4
app.py CHANGED
@@ -1,7 +1,161 @@
  import gradio as gr

- def greet(name):
-     return "Hello " + name + "!!"

- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
- demo.launch()
+ import os
+ import pandas as pd
+ import torch
+ import gc
+ import re
+ import random
+ import spaces  # provides the @spaces.GPU decorator used below
+ from tqdm.auto import tqdm
+ from collections import deque
+ from optimum.quanto import freeze, qfloat8, quantize
+ from diffusers import FlowMatchEulerDiscreteScheduler, AutoencoderKL
+ from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
+ from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
+ from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
+ from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
  import gradio as gr

+ dtype = torch.bfloat16

+ # Set environment variables for local path
+ os.environ['FLUX_DEV'] = '.'
+ os.environ['AE'] = '.'
+
+ bfl_repo = 'black-forest-labs/FLUX.1-schnell'
+ revision = 'refs/pr/1'
+
+ scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(bfl_repo, subfolder='scheduler', revision=revision)
+ text_encoder = CLIPTextModel.from_pretrained('openai/clip-vit-large-patch14', torch_dtype=dtype)
+ tokenizer = CLIPTokenizer.from_pretrained('openai/clip-vit-large-patch14', torch_dtype=dtype)
+ text_encoder_2 = T5EncoderModel.from_pretrained(bfl_repo, subfolder='text_encoder_2', torch_dtype=dtype, revision=revision)
+ tokenizer_2 = T5TokenizerFast.from_pretrained(bfl_repo, subfolder='tokenizer_2', torch_dtype=dtype, revision=revision)
+ vae = AutoencoderKL.from_pretrained(bfl_repo, subfolder='vae', torch_dtype=dtype, revision=revision)
+ transformer = FluxTransformer2DModel.from_pretrained(bfl_repo, subfolder='transformer', torch_dtype=dtype, revision=revision)
+
+ # Quantize the two largest components to qfloat8 to reduce memory use
+ quantize(transformer, weights=qfloat8)
+ freeze(transformer)
+ quantize(text_encoder_2, weights=qfloat8)
+ freeze(text_encoder_2)
+
+ # Build the pipeline without the quantized modules, then attach them afterwards
+ pipe = FluxPipeline(
+     scheduler=scheduler,
+     text_encoder=text_encoder,
+     tokenizer=tokenizer,
+     text_encoder_2=None,
+     tokenizer_2=tokenizer_2,
+     vae=vae,
+     transformer=None,
+ )
+ pipe.text_encoder_2 = text_encoder_2
+ pipe.transformer = transformer
+ pipe.enable_model_cpu_offload()
+
+ # Create a directory to save the generated images
+ output_dir = 'generated_images'
+ os.makedirs(output_dir, exist_ok=True)
+
+ # Function to generate a detailed visual description prompt
+ def generate_description_prompt(subject, user_prompt):
+     prompt = f"write concise vivid visual description enclosed in brackets like [ <description> ] less than 100 words of {user_prompt} different from {subject}. "
+     try:
+         generated_text = text_generator(prompt, max_length=230, num_return_sequences=1, truncation=True)[0]['generated_text']
+         generated_description = re.sub(rf'{re.escape(prompt)}\s*', '', generated_text).strip()  # Remove the prompt from the generated text
+         return generated_description if generated_description else None
+     except Exception as e:
+         print(f"Error generating description for subject '{subject}': {e}")
+         return None
+
+ # Function to parse descriptions from a given text
+ def parse_descriptions(text):
+     # Find all descriptions enclosed in brackets
+     descriptions = re.findall(r'\[([^\[\]]+)\]', text)
+     # Filter descriptions with at least 3 words
+     descriptions = [desc.strip() for desc in descriptions if len(desc.split()) >= 3]
+     return descriptions
+
+ # Seed words pool
+ seed_words = []
+
+ used_words = set()
+ paused = False
+
+ # Queue to store parsed descriptions
+ parsed_descriptions_queue = deque()
+
+ @spaces.GPU
+ def generate_and_store_descriptions(user_prompt, batch_size=100, max_iterations=50):
+     global paused, text_generator  # text_generator is read by generate_description_prompt
+     descriptions = []
+     description_queue = deque()
+     iteration_count = 0
+
+     # Initialize the text generation pipeline with 16-bit precision
+     print("Initializing the text generation pipeline with 16-bit precision...")
+     model_name = 'meta-llama/Meta-Llama-3.1-8B-Instruct'
+     model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map='auto')
+     tokenizer = AutoTokenizer.from_pretrained(model_name)
+     text_generator = pipeline('text-generation', model=model, tokenizer=tokenizer)
+     print("Text generation pipeline initialized with 16-bit precision.")
+
+     while iteration_count < max_iterations:
+         if paused:
+             break
+
+         # Select a subject that has not been used
+         available_subjects = [word for word in seed_words if word not in used_words]
+         if not available_subjects:
+             print("No more available subjects to use.")
+             break
+
+         subject = random.choice(available_subjects)
+         generated_description = generate_description_prompt(subject, user_prompt)
+
+         if generated_description:
+             # Remove any offending symbols
+             clean_description = generated_description.encode('ascii', 'ignore').decode('ascii')
+             description_queue.append({'subject': subject, 'description': clean_description})
+
+             # Print the generated description to the command line
+             print(f"Generated description for subject '{subject}': {clean_description}")
+
+             # Update used words and seed words
+             used_words.add(subject)
+             seed_words.append(clean_description)  # Add the generated description to the seed bank array
+
+             # Parse and append descriptions every 3 iterations
+             if iteration_count % 3 == 0:
+                 parsed_descriptions = parse_descriptions(clean_description)
+                 parsed_descriptions_queue.extend(parsed_descriptions)
+
+         iteration_count += 1
+
+     return list(parsed_descriptions_queue)
+
+ @spaces.GPU(duration=120)
+ def generate_images_from_parsed_descriptions():
+     # If there are fewer than 13 descriptions, use whatever is available
+     if len(parsed_descriptions_queue) < 13:
+         prompts = list(parsed_descriptions_queue)
+     else:
+         prompts = [parsed_descriptions_queue.popleft() for _ in range(13)]
+
+     # Generate images from the parsed descriptions
+     images = []
+     for prompt in prompts:
+         images.extend(pipe(prompt, num_images_per_prompt=1).images)
+
+     return images
+
+ # Create Gradio Interface
+ description_interface = gr.Interface(
+     fn=generate_and_store_descriptions,
+     inputs=gr.Textbox(lines=2, placeholder="Enter a prompt for descriptions..."),
+     outputs="json"
+ )
+
+ image_interface = gr.Interface(
+     fn=generate_images_from_parsed_descriptions,
+     inputs=None,
+     outputs=gr.Gallery()
+ )
+
+ gr.TabbedInterface([description_interface, image_interface], ["Generate Descriptions", "Generate Images"]).launch()
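
The qfloat8 quantization of the FLUX transformer and the T5 text encoder is the main memory-saving step in this commit. For reference, here is a minimal, self-contained sketch of the same optimum.quanto quantize/freeze pattern on a toy torch module (the module and its sizes are illustrative, not part of the app):

import torch
from optimum.quanto import freeze, qfloat8, quantize

# Toy stand-in for a large model; any nn.Module containing Linear layers is handled the same way.
toy_model = torch.nn.Sequential(
    torch.nn.Linear(64, 64),
    torch.nn.ReLU(),
    torch.nn.Linear(64, 8),
)

quantize(toy_model, weights=qfloat8)  # mark Linear weights for qfloat8 quantization
freeze(toy_model)                     # replace the original weights with their quantized versions

with torch.no_grad():
    print(toy_model(torch.randn(1, 64)).shape)  # torch.Size([1, 8])

In app.py, quantization and freezing happen before the modules are attached to the FluxPipeline, so inference runs directly on the quantized weights.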
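The bracket-parsing helper is pure Python and can be sanity-checked without loading any models; a quick standalone check (the sample text below is made up):

import re

def parse_descriptions(text):
    # Same logic as the helper in app.py: keep bracketed spans of at least 3 words.
    descriptions = re.findall(r'\[([^\[\]]+)\]', text)
    return [desc.strip() for desc in descriptions if len(desc.split()) >= 3]

sample = "Two results: [a misty pine forest at dawn] and [ red kite over the dunes ]."
print(parse_descriptions(sample))
# ['a misty pine forest at dawn', 'red kite over the dunes']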