# File size: 7,519 Bytes
# 5f554bd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
import multiprocessing

if __name__ == '__main__':
    # Everything below (model loading, helper definitions, the Gradio app) runs
    # only when this file is executed as a script.
    # NOTE(review): 'spawn' is selected before any other import, presumably so
    # CUDA-using worker processes start fresh instead of being forked — confirm.
    multiprocessing.set_start_method('spawn')

    import spaces  # Import spaces at the very beginning
    import os
    import pandas as pd
    import torch
    import gc
    import re
    import random
    from tqdm.auto import tqdm
    from collections import deque
    from optimum.quanto import freeze, qfloat8, quantize
    from diffusers import FlowMatchEulerDiscreteScheduler, AutoencoderKL
    from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
    from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
    from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
    from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
    import gradio as gr
    from accelerate import Accelerator

    # Instantiate the Accelerator.
    # NOTE(review): `accelerator` is not referenced anywhere else in the visible
    # code; confirm it is still needed.
    accelerator = Accelerator()

    # All FLUX components (and the CLIP text encoder) are loaded in bfloat16.
    dtype = torch.bfloat16

    # NOTE(review): nothing in the visible code reads FLUX_DEV or AE; presumably
    # they are consumed by an external loader expecting a local path — confirm.
    os.environ['FLUX_DEV'] = '.'
    os.environ['AE'] = '.'

    # Every FLUX sub-component below is pulled from this repo at this revision
    # (a PR ref rather than the default branch).
    bfl_repo = 'black-forest-labs/FLUX.1-schnell'
    revision = 'refs/pr/1'

    # Load each pipeline component individually so the large ones can be
    # quantized before the pipeline is assembled.
    # NOTE(review): tokenizers hold no weights; `torch_dtype` is presumably
    # ignored by CLIPTokenizer / T5TokenizerFast — confirm and drop if so.
    scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(bfl_repo, subfolder='scheduler', revision=revision)
    text_encoder = CLIPTextModel.from_pretrained('openai/clip-vit-large-patch14', torch_dtype=dtype)
    tokenizer = CLIPTokenizer.from_pretrained('openai/clip-vit-large-patch14', torch_dtype=dtype)
    text_encoder_2 = T5EncoderModel.from_pretrained(bfl_repo, subfolder='text_encoder_2', torch_dtype=dtype, revision=revision)
    tokenizer_2 = T5TokenizerFast.from_pretrained(bfl_repo, subfolder='tokenizer_2', torch_dtype=dtype, revision=revision)
    vae = AutoencoderKL.from_pretrained(bfl_repo, subfolder='vae', torch_dtype=dtype, revision=revision)
    transformer = FluxTransformer2DModel.from_pretrained(bfl_repo, subfolder='transformer', torch_dtype=dtype, revision=revision)

    # Quantize the weights of the two largest modules (the FLUX transformer and
    # the T5 encoder) to float8 with optimum-quanto, then freeze them.
    quantize(transformer, weights=qfloat8)
    freeze(transformer)
    quantize(text_encoder_2, weights=qfloat8)
    freeze(text_encoder_2)

    # The quantized modules are passed as None to the constructor and attached
    # afterwards. NOTE(review): presumably a workaround for handing
    # quanto-quantized modules to the pipeline constructor — confirm before
    # simplifying.
    pipe = FluxPipeline(
        scheduler=scheduler,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        text_encoder_2=None,
        tokenizer_2=tokenizer_2,
        vae=vae,
        transformer=None,
    )
    pipe.text_encoder_2 = text_encoder_2
    pipe.transformer = transformer
    # Let diffusers keep submodules on the CPU except while they are in use.
    pipe.enable_model_cpu_offload()

    # Directory intended for generated images.
    # NOTE(review): generate_images does not write to this directory in the
    # visible code, so it currently stays empty.
    output_dir = 'generated_images'
    os.makedirs(output_dir, exist_ok=True)

    # Ask the text-generation model for one concise, bracketed visual description
    def generate_description_prompt(subject, user_prompt, text_generator, max_length=160):
        """Ask *text_generator* for a short bracketed description of *user_prompt*.

        Args:
            subject: Seed subject the description should be different from.
            user_prompt: What the description should depict.
            text_generator: Callable following the Hugging Face text-generation
                pipeline convention: ``text_generator(prompt, **kwargs)`` returns
                a list whose first item has a ``'generated_text'`` key.
            max_length: Forwarded to the generator unchanged. NOTE(review): for
                HF generation this budget includes the prompt tokens, so long
                descriptions may be cut before their closing bracket.

        Returns:
            The generated text with the echoed prompt prefix removed and
            surrounding whitespace stripped, or ``None`` when that text is empty
            or generation fails (the error is printed, not raised).
        """
        prompt = f"write concise vivid visual description enclosed in brackets like [ <description> ] less than 100 words of {user_prompt} different from {subject}. "
        try:
            generated_text = text_generator(prompt, max_length=max_length, num_return_sequences=1, truncation=True)[0]['generated_text']
        except Exception as e:
            # Best effort: one failed generation must not abort the caller's loop.
            print(f"Error generating description for subject '{subject}': {e}")
            return None

        # The pipeline echoes the prompt at the start of its output by default.
        # Strip only that leading echo; a regex substitution would also delete
        # the prompt text anywhere else in the model's answer.
        if generated_text.startswith(prompt):
            generated_text = generated_text[len(prompt):]
        return generated_text.strip() or None

    # Function to parse descriptions from a given text
    def parse_descriptions(text):
        """Return the bracketed descriptions found in *text*.

        Each ``[...]`` span containing no inner brackets is extracted, stripped
        of surrounding whitespace, and kept only when it has at least three
        whitespace-separated words.
        """
        candidates = (match.group(1) for match in re.finditer(r'\[([^\[\]]+)\]', text))
        return [chunk.strip() for chunk in candidates if len(chunk.split()) >= 3]

    # Usage limits: how many parsed descriptions generate_descriptions collects
    # and how many images generate_images renders.
    MAX_DESCRIPTIONS = 10
    MAX_IMAGES = 5

    # Mutable module-level state read and updated by the functions below.
    seed_words = []                      # candidate subjects: quoted user seeds plus generated text
    used_words = set()                   # subjects already sent to the text model
    parsed_descriptions_queue = deque()  # bracketed descriptions extracted so far
    paused = False                       # checked every loop iteration; nothing in view sets it

    @spaces.GPU
    def generate_descriptions(user_prompt, seed_words_input, batch_size=100, max_iterations=50):
        """Generate visual descriptions with Llama 3.1 and return the parsed ones.

        Loads the text-generation model and adds every double-quoted substring
        of *seed_words_input* to the shared ``seed_words`` pool. It then
        repeatedly picks an unused subject, asks the model for a bracketed
        description of *user_prompt*, and feeds that description back into the
        pool as a new subject.

        Args:
            user_prompt: What the descriptions should depict.
            seed_words_input: Free text; each ``"quoted"`` substring becomes a subject.
            batch_size: Unused; kept so existing callers keep working.
            max_iterations: Upper bound on generation attempts (including failed ones).

        Returns:
            A list copy of the shared ``parsed_descriptions_queue``.

        NOTE(review): ``seed_words``, ``used_words`` and
        ``parsed_descriptions_queue`` are module-level and never reset, so state
        carries over between requests. Once the queue holds ``MAX_DESCRIPTIONS``
        entries, later calls load the model but generate nothing new.
        """
        iteration_count = 0

        # Load the text-generation model for this call in float16.
        print("Initializing the text generation pipeline with 16-bit precision...")
        model_name = 'meta-llama/Meta-Llama-3.1-8B-Instruct'
        model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map='auto')
        llm_tokenizer = AutoTokenizer.from_pretrained(model_name)
        text_generator = pipeline('text-generation', model=model, tokenizer=llm_tokenizer)
        print("Text generation pipeline initialized with 16-bit precision.")

        try:
            # Every "double-quoted" substring of the user input becomes a subject.
            seed_words.extend(re.findall(r'"(.*?)"', seed_words_input))

            while iteration_count < max_iterations and len(parsed_descriptions_queue) < MAX_DESCRIPTIONS:
                if paused:
                    break

                # Only subjects that have not been used yet are eligible.
                available_subjects = [word for word in seed_words if word not in used_words]
                if not available_subjects:
                    print("No more available subjects to use.")
                    break

                subject = random.choice(available_subjects)
                generated_description = generate_description_prompt(subject, user_prompt, text_generator)

                if generated_description:
                    # Drop any non-ASCII characters from the model output.
                    clean_description = generated_description.encode('ascii', 'ignore').decode('ascii')
                    print(f"Generated description for subject '{subject}': {clean_description}")

                    # Retire the subject and feed the description back in as a new one.
                    used_words.add(subject)
                    seed_words.append(clean_description)

                    # Only every third attempt's output is parsed; iteration_count
                    # also counts attempts that produced nothing.
                    if iteration_count % 3 == 0:
                        parsed_descriptions_queue.extend(parse_descriptions(clean_description))

                iteration_count += 1
        finally:
            # Drop the local references to the 8B model and ask Python and the
            # CUDA caching allocator to release its memory now, before image
            # generation runs, rather than whenever the objects happen to die.
            del text_generator, model, llm_tokenizer
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        return list(parsed_descriptions_queue)

    @spaces.GPU(duration=120)
    def generate_images(parsed_descriptions):
        """Render one image per prompt for at most ``MAX_IMAGES`` prompts.

        Args:
            parsed_descriptions: Sliceable sequence of text prompts. Only the
                first ``MAX_IMAGES`` are used, and the sequence is not modified.

        Returns:
            A list of the images produced by ``pipe`` (empty when there are no prompts).
        """
        # One slice covers both the "fewer than MAX_IMAGES" and "at least
        # MAX_IMAGES" cases. Unlike repeated pop(0), it does not mutate the
        # caller's list and is not O(n) per element.
        prompts = list(parsed_descriptions[:MAX_IMAGES])

        images = []
        for prompt in prompts:
            # FluxPipeline.__call__ takes `num_images_per_prompt`. The previous
            # `num_images` keyword is not in its signature and raised TypeError.
            # NOTE(review): FLUX.1-schnell is normally run with very few steps
            # (e.g. num_inference_steps=4, guidance_scale=0.0). The pipeline
            # defaults used here may not fit the 120 s GPU budget — confirm.
            images.extend(pipe(prompt, num_images_per_prompt=1).images)

        return images

    # Gradio callback: generate descriptions, then render images from them
    def combined_function(user_prompt, seed_words_input):
        """Gradio handler: build descriptions from the inputs, then render them.

        Returns whatever ``generate_images`` returns for the descriptions that
        ``generate_descriptions`` produced.
        """
        return generate_images(generate_descriptions(user_prompt, seed_words_input))

    # Two text inputs (prompt, quoted seed words) feed combined_function, and
    # the returned images are shown in a gallery.
    prompt_input = gr.Textbox(lines=2, placeholder="Enter a prompt for descriptions...")
    seed_input = gr.Textbox(lines=2, placeholder='Enter seed words in quotes, e.g., "cat", "dog", "sunset"...')
    gallery_output = gr.Gallery()

    interface = gr.Interface(fn=combined_function, inputs=[prompt_input, seed_input], outputs=gallery_output)
    interface.launch()