File size: 7,063 Bytes
4e91b3f
 
 
 
 
 
 
 
 
e66b780
4e91b3f
4122838
 
 
 
4e91b3f
 
 
 
 
 
 
 
4387c36
4e91b3f
567999c
 
 
 
 
 
 
 
 
 
519d719
567999c
1716c9d
 
 
 
 
 
 
87f446f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1716c9d
87f446f
1716c9d
567999c
4e91b3f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
519d719
 
 
 
4e91b3f
d5f6733
4e91b3f
 
 
 
 
 
2843c6f
 
 
 
 
 
4e91b3f
2843c6f
 
4e91b3f
2843c6f
 
 
4e91b3f
2843c6f
 
 
 
4e91b3f
 
 
 
2843c6f
4e91b3f
 
 
1716c9d
519d719
7fbcaaf
 
 
4e91b3f
 
7fbcaaf
 
4e91b3f
 
 
 
 
519d719
e66b780
519d719
5f554bd
 
8b7ec26
519d719
 
5f554bd
8b7ec26
 
 
776a3ec
292be6f
25d995e
2beec1e
 
519d719
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
import torch.multiprocessing as mp
import torch
import os
import re
import random
from collections import deque
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
import gradio as gr
from accelerate import Accelerator
import spaces

# Multiprocessing start method must be 'spawn' for CUDA-safe workers;
# only set it if it hasn't been fixed already (setting twice raises).
if mp.get_start_method(allow_none=True) != 'spawn':
    mp.set_start_method('spawn')

# Instantiate the Accelerator (device placement helper; not otherwise used below)
accelerator = Accelerator()

# Default dtype for the FLUX diffusion components loaded in initialize_diffusers()
dtype = torch.bfloat16

# Set environment variables for local path
os.environ['FLUX_DEV'] = '.'
os.environ['AE'] = '.'
os.environ['HF_HUB_ENABLE_HF_TRANSFER'] = 'false'  # Disable HF_HUB_ENABLE_HF_TRANSFER

# Seed words pool — mutable module-level state, grows across calls to
# generate_descriptions() (new quoted seeds AND generated descriptions are appended)
seed_words = []

# Subjects already consumed by generate_descriptions(); never cleared
used_words = set()

# Queue to store parsed descriptions (shared across requests; capped via MAX_DESCRIPTIONS)
parsed_descriptions_queue = deque()

# Usage limits
MAX_DESCRIPTIONS = 30
MAX_IMAGES = 3  # Limit to 3 images

# Preload models and checkpoints at import time — this downloads/loads an 8B
# LLM, so module import is slow and memory-heavy by design.
print("Preloading models and checkpoints...")
model_name = 'NousResearch/Meta-Llama-3.1-8B-Instruct'
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map='auto')
tokenizer = AutoTokenizer.from_pretrained(model_name)
text_generator = pipeline('text-generation', model=model, tokenizer=tokenizer)

def initialize_diffusers():
    """Build and return a FLUX.1-schnell pipeline with 8-bit-quantized
    transformer and T5 encoder, using CPU offload to fit in limited VRAM.

    Imports are deferred to keep module import time down when only the
    text-generation path is exercised.
    """
    from optimum.quanto import freeze, qfloat8, quantize
    from diffusers import FlowMatchEulerDiscreteScheduler, AutoencoderKL
    from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
    from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
    from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast

    bfl_repo = 'black-forest-labs/FLUX.1-schnell'
    revision = 'refs/pr/1'

    scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(bfl_repo, subfolder='scheduler', revision=revision)
    text_encoder = CLIPTextModel.from_pretrained('openai/clip-vit-large-patch14', torch_dtype=dtype)
    # NOTE(review): torch_dtype passed to the tokenizers below is presumably
    # ignored (tokenizers hold no weights) — harmless, but worth confirming.
    tokenizer = CLIPTokenizer.from_pretrained('openai/clip-vit-large-patch14', torch_dtype=dtype)
    text_encoder_2 = T5EncoderModel.from_pretrained(bfl_repo, subfolder='text_encoder_2', torch_dtype=dtype, revision=revision)
    tokenizer_2 = T5TokenizerFast.from_pretrained(bfl_repo, subfolder='tokenizer_2', torch_dtype=dtype, revision=revision)
    vae = AutoencoderKL.from_pretrained(bfl_repo, subfolder='vae', torch_dtype=dtype, revision=revision)
    transformer = FluxTransformer2DModel.from_pretrained(bfl_repo, subfolder='transformer', torch_dtype=dtype, revision=revision)

    # Quantize the two largest components to 8-bit floats and freeze them
    # BEFORE attaching them to the pipeline.
    quantize(transformer, weights=qfloat8)
    freeze(transformer)
    quantize(text_encoder_2, weights=qfloat8)
    freeze(text_encoder_2)

    # Construct the pipeline with the quantized components set to None, then
    # assign them afterwards — this sidesteps the constructor's dtype handling
    # of already-quantized modules. Order matters; do not "simplify" this.
    pipe = FluxPipeline(
        scheduler=scheduler,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        text_encoder_2=None,
        tokenizer_2=tokenizer_2,
        vae=vae,
        transformer=None,
    )
    pipe.text_encoder_2 = text_encoder_2
    pipe.transformer = transformer
    # Move modules to GPU lazily per forward pass to reduce peak VRAM.
    pipe.enable_model_cpu_offload()

    return pipe

# Build the (quantized, CPU-offloaded) FLUX image pipeline at import time.
pipe = initialize_diffusers()
print("Models and checkpoints preloaded.")

def generate_description_prompt(subject, user_prompt, text_generator):
    """Ask *text_generator* for a short bracketed visual description.

    Returns the generated text with the echoed instruction stripped off,
    or None when generation fails or produces nothing beyond the prompt.
    """
    instruction = f"write concise vivid visual description enclosed in brackets like [ <description> ] less than 100 words of {user_prompt} different from {subject}. "
    try:
        raw_output = text_generator(instruction, max_length=160, num_return_sequences=1, truncation=True)
        raw_text = raw_output[0]['generated_text']
        # Causal LMs echo the prompt; remove it (plus trailing whitespace).
        description = re.sub(rf'{re.escape(instruction)}\s*', '', raw_text).strip()
        return description or None
    except Exception as e:
        # Best-effort: log and signal failure with None rather than crashing.
        print(f"Error generating description for subject '{subject}': {e}")
        return None

def parse_descriptions(text):
    """Extract bracketed spans of at least three words from *text*."""
    return [
        candidate.strip()
        for candidate in re.findall(r'\[([^\[\]]+)\]', text)
        if len(candidate.split()) >= 3
    ]

def format_descriptions(descriptions):
    """Join descriptions into one newline-separated display string."""
    return "\n".join(descriptions)

@spaces.GPU
def generate_descriptions(user_prompt, seed_words_input, batch_size=100, max_iterations=1):  # Set max_iterations to 1
    """Generate bracketed scene descriptions seeded by quoted words.

    Mutates module-level state: extends ``seed_words`` (with both new quoted
    seeds and generated descriptions), adds consumed subjects to
    ``used_words``, and appends parsed results to ``parsed_descriptions_queue``.

    :param user_prompt: theme the descriptions should relate to.
    :param seed_words_input: raw text; only double-quoted tokens ("cat") are used.
    :param batch_size: unused; kept for interface compatibility.
    :param max_iterations: cap on generation attempts per outer pass.
    :return: snapshot of ``parsed_descriptions_queue`` as a list.
    """
    description_queue = deque()  # local log of subject/description pairs
    iteration_count = 0

    # Pull only the double-quoted tokens out of the user's seed input.
    seed_words.extend(re.findall(r'"(.*?)"', seed_words_input))

    for _ in range(2):  # Perform two iterations
        # NOTE(review): iteration_count is NOT reset between the two outer
        # passes, so with the default max_iterations=1 the second pass is a
        # no-op. Preserved as-is; confirm whether a reset was intended.
        while iteration_count < max_iterations and len(parsed_descriptions_queue) < MAX_DESCRIPTIONS:
            available_subjects = [word for word in seed_words if word not in used_words]
            if not available_subjects:
                print("No more available subjects to use.")
                break

            subject = random.choice(available_subjects)
            generated_description = generate_description_prompt(subject, user_prompt, text_generator)

            if generated_description:
                # Drop non-ASCII characters before downstream use.
                clean_description = generated_description.encode('ascii', 'ignore').decode('ascii')
                description_queue.append({'subject': subject, 'description': clean_description})

                print(f"Generated description for subject '{subject}': {clean_description}")

                used_words.add(subject)
                # Feed the description back in as a future seed subject.
                seed_words.append(clean_description)

                parsed_descriptions = parse_descriptions(clean_description)
                parsed_descriptions_queue.extend(parsed_descriptions)

            iteration_count += 1

    return list(parsed_descriptions_queue)

@spaces.GPU(duration=120)
def generate_images(parsed_descriptions, max_iterations=3):  # Set max_iterations to 3
    """Render one 512x512 image per description using the preloaded FLUX pipe.

    ``max_iterations`` is passed through as the diffusion step count
    (num_inference_steps), not as a loop bound.
    """
    # Cap the number of prompts at MAX_IMAGES (currently 3).
    if len(parsed_descriptions) > MAX_IMAGES:
        parsed_descriptions = parsed_descriptions[:MAX_IMAGES]

    images = []
    for prompt in parsed_descriptions:
        images.extend(pipe(prompt, num_inference_steps=max_iterations, height=512, width=512).images)  # Set resolution to 512 x 512

    return images

def combined_function(user_prompt, seed_words_input):
    """Generate descriptions from the prompt/seeds, then images from them.

    Returns (formatted description text, list of images).
    """
    descriptions = generate_descriptions(user_prompt, seed_words_input)
    return format_descriptions(descriptions), generate_images(descriptions)

if __name__ == '__main__':
    # combined_function already maps (prompt, seed words) -> (text, images),
    # exactly the signature Gradio expects, so the previous pass-through
    # wrapper was redundant and is removed.
    interface = gr.Interface(
        fn=combined_function,
        inputs=[
            gr.Textbox(lines=2, placeholder="Enter a prompt for descriptions..."),
            gr.Textbox(lines=2, placeholder='Enter seed words in quotes, e.g., "cat", "dog", "sunset"...'),
        ],
        outputs=[gr.Textbox(label="Generated Descriptions"), gr.Gallery(label="Generated Images")],
        live=False,  # Set live to False
        allow_flagging='never'  # Disable flagging
    )

    # share=True exposes a public Gradio link.
    interface.launch(share=True)