import os
import re
import random
from collections import deque

import torch
import gradio as gr
import spaces  # provides the @spaces.GPU decorators used below (Hugging Face ZeroGPU)
from optimum.quanto import freeze, qfloat8, quantize
from diffusers import FlowMatchEulerDiscreteScheduler, AutoencoderKL
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer

dtype = torch.bfloat16

# Environment variables used by the original BFL reference loader; they are
# not read by the diffusers pipeline below but are kept for compatibility
os.environ['FLUX_DEV'] = '.'
os.environ['AE'] = '.'

bfl_repo = 'black-forest-labs/FLUX.1-schnell'
revision = 'refs/pr/1'

scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(bfl_repo, subfolder='scheduler', revision=revision)
text_encoder = CLIPTextModel.from_pretrained('openai/clip-vit-large-patch14', torch_dtype=dtype)
tokenizer = CLIPTokenizer.from_pretrained('openai/clip-vit-large-patch14')
text_encoder_2 = T5EncoderModel.from_pretrained(bfl_repo, subfolder='text_encoder_2', torch_dtype=dtype, revision=revision)
tokenizer_2 = T5TokenizerFast.from_pretrained(bfl_repo, subfolder='tokenizer_2', revision=revision)
vae = AutoencoderKL.from_pretrained(bfl_repo, subfolder='vae', torch_dtype=dtype, revision=revision)
transformer = FluxTransformer2DModel.from_pretrained(bfl_repo, subfolder='transformer', torch_dtype=dtype, revision=revision)

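# Quantize the two heaviest components (the Flux transformer and the T5 text
# encoder) to 8-bit float weights with optimum-quanto, then freeze them so the
# quantized weights are kept for inference; this substantially cuts VRAM use
# compared with keeping them in bf16.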
quantize(transformer, weights=qfloat8)
freeze(transformer)
quantize(text_encoder_2, weights=qfloat8)
freeze(text_encoder_2)

pipe = FluxPipeline(
    scheduler=scheduler,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    text_encoder_2=None,
    tokenizer_2=tokenizer_2,
    vae=vae,
    transformer=None,
)
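# text_encoder_2 and transformer are passed as None above and attached after
# construction so the pipeline does not try to move or re-cast the quantized
# modules; enable_model_cpu_offload() then keeps each component on the GPU
# only while it is in use.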
pipe.text_encoder_2 = text_encoder_2
pipe.transformer = transformer
pipe.enable_model_cpu_offload()

# Create a directory to save the generated images
output_dir = 'generated_images'
os.makedirs(output_dir, exist_ok=True)

# Function to generate a concise visual description for a subject via the text-generation model
def generate_description_prompt(subject, user_prompt, text_generator):
    prompt = f"write concise vivid visual description enclosed in brackets like [ <description> ] less than 100 words of {user_prompt} different from {subject}. "
    try:
        generated_text = text_generator(prompt, max_length=160, num_return_sequences=1, truncation=True)[0]['generated_text']
        generated_description = re.sub(rf'{re.escape(prompt)}\s*', '', generated_text).strip()  # Remove the prompt from the generated text
        return generated_description if generated_description else None
    except Exception as e:
        print(f"Error generating description for subject '{subject}': {e}")
        return None

# Function to parse descriptions from a given text
def parse_descriptions(text):
    # Find all descriptions enclosed in brackets
    descriptions = re.findall(r'\[([^\[\]]+)\]', text)
    # Filter descriptions with at least 3 words
    descriptions = [desc.strip() for desc in descriptions if len(desc.split()) >= 3]
    return descriptions
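
# For example, parse_descriptions("intro [a misty pine forest at dawn] [x y]")
# returns ['a misty pine forest at dawn']; the second match is dropped because
# it has fewer than 3 words.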

# Seed words pool: starting subjects for description generation (seeded from
# the user prompt on first use; generated descriptions are appended as new seeds)
seed_words = []

used_words = set()
paused = False

# FIFO queue of parsed descriptions handed from the description tab to the image tab
parsed_descriptions_queue = deque()

@spaces.GPU
def generate_and_store_descriptions(user_prompt, max_iterations=50):
    global paused
    description_queue = deque()
    iteration_count = 0

    # Initialize the text generation pipeline with 16-bit precision
    print("Initializing the text generation pipeline with 16-bit precision...")
    model_name = 'meta-llama/Meta-Llama-3.1-8B-Instruct'
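    # Note: Meta-Llama-3.1-8B-Instruct is a gated repository on the Hugging
    # Face Hub; access must be granted and an auth token configured first.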
    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map='auto')
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    text_generator = pipeline('text-generation', model=model, tokenizer=tokenizer)
    print("Text generation pipeline initialized with 16-bit precision.")

    while iteration_count < max_iterations:
        if paused:
            break

        # Seed the pool from the user prompt on the first pass so there is
        # always at least one subject to start from
        if not seed_words:
            seed_words.append(user_prompt)

        # Select a subject that has not been used yet
        available_subjects = [word for word in seed_words if word not in used_words]
        if not available_subjects:
            print("No more available subjects to use.")
            break

        subject = random.choice(available_subjects)
        generated_description = generate_description_prompt(subject, user_prompt, text_generator)
        
        if generated_description:
            # Remove any offending symbols
            clean_description = generated_description.encode('ascii', 'ignore').decode('ascii')
            description_queue.append({'subject': subject, 'description': clean_description})

            # Print the generated description to the command line
            print(f"Generated description for subject '{subject}': {clean_description}")

            # Mark the subject as used and feed the generated description back
            # into the seed pool so later iterations can build on it
            used_words.add(subject)
            seed_words.append(clean_description)

            # Parse bracketed descriptions out of the generated text and queue
            # them for image generation
            parsed_descriptions_queue.extend(parse_descriptions(clean_description))

        iteration_count += 1

    return list(parsed_descriptions_queue)

@spaces.GPU(duration=120)
def generate_images_from_parsed_descriptions():
    # Drain up to 13 queued descriptions; if fewer are available, use them all
    batch = min(len(parsed_descriptions_queue), 13)
    prompts = [parsed_descriptions_queue.popleft() for _ in range(batch)]

    # Generate one image per description. FLUX.1-schnell is a few-step
    # distilled model, so 4 inference steps with guidance disabled is the
    # recommended setting.
    images = []
    for prompt in prompts:
        images.extend(pipe(prompt, num_inference_steps=4, guidance_scale=0.0, num_images_per_prompt=1).images)

    return images

# Create Gradio Interface
description_interface = gr.Interface(
    fn=generate_and_store_descriptions,
    inputs=gr.Textbox(lines=2, placeholder="Enter a prompt for descriptions..."),
    outputs="json"
)

image_interface = gr.Interface(
    fn=generate_images_from_parsed_descriptions,
    inputs=None,
    outputs=gr.Gallery()
)

demo = gr.TabbedInterface(
    [description_interface, image_interface],
    ["Generate Descriptions", "Generate Images"],
)

if __name__ == '__main__':
    demo.launch()