import spaces
import torch.multiprocessing as mp
import torch
import os
import pandas as pd
import gc
import re
import random
from tqdm.auto import tqdm
from collections import deque
from optimum.quanto import freeze, qfloat8, quantize
from diffusers import FlowMatchEulerDiscreteScheduler, AutoencoderKL
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
import gradio as gr
from accelerate import Accelerator

# CUDA requires the 'spawn' start method; set it only if it has not been set already
if mp.get_start_method(allow_none=True) != 'spawn':
    mp.set_start_method('spawn')

# Instantiate the Accelerator
accelerator = Accelerator()

dtype = torch.bfloat16

# Set environment variables for local paths
os.environ['FLUX_DEV'] = '.'
os.environ['AE'] = '.'

bfl_repo = 'black-forest-labs/FLUX.1-schnell'
revision = 'refs/pr/1'

# Load the FLUX.1-schnell components individually so the two largest ones can
# be quantized before the pipeline is assembled
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(bfl_repo, subfolder='scheduler', revision=revision)
text_encoder = CLIPTextModel.from_pretrained('openai/clip-vit-large-patch14', torch_dtype=dtype)
tokenizer = CLIPTokenizer.from_pretrained('openai/clip-vit-large-patch14')
text_encoder_2 = T5EncoderModel.from_pretrained(bfl_repo, subfolder='text_encoder_2', torch_dtype=dtype, revision=revision)
tokenizer_2 = T5TokenizerFast.from_pretrained(bfl_repo, subfolder='tokenizer_2', revision=revision)
vae = AutoencoderKL.from_pretrained(bfl_repo, subfolder='vae', torch_dtype=dtype, revision=revision)
transformer = FluxTransformer2DModel.from_pretrained(bfl_repo, subfolder='transformer', torch_dtype=dtype, revision=revision)

# Quantize the transformer and the T5 encoder to 8-bit floats to cut VRAM use,
# then freeze them so the quantized weights are not modified by later code
quantize(transformer, weights=qfloat8)
freeze(transformer)
quantize(text_encoder_2, weights=qfloat8)
freeze(text_encoder_2)

# Assemble the pipeline without the quantized modules, then attach them
# afterwards so the constructor's dtype handling does not interfere with them
pipe = FluxPipeline(
    scheduler=scheduler,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    text_encoder_2=None,
    tokenizer_2=tokenizer_2,
    vae=vae,
    transformer=None,
)
pipe.text_encoder_2 = text_encoder_2
pipe.transformer = transformer
pipe.enable_model_cpu_offload()

# Create a directory to save the generated images
output_dir = 'generated_images'
os.makedirs(output_dir, exist_ok=True)


# Generate one vivid visual description with the text-generation pipeline
def generate_description_prompt(subject, user_prompt, text_generator):
    prompt = (
        f"Write a concise, vivid visual description of {user_prompt}, "
        f"different from {subject}, in less than 100 words, "
        f"enclosed in brackets like [ ]. "
    )
    try:
        generated_text = text_generator(prompt, max_length=160, num_return_sequences=1, truncation=True)[0]['generated_text']
        # Strip the prompt itself from the generated text
        generated_description = re.sub(rf'{re.escape(prompt)}\s*', '', generated_text).strip()
        return generated_description if generated_description else None
    except Exception as e:
        print(f"Error generating description for subject '{subject}': {e}")
        return None


# Parse bracketed descriptions out of a given text
def parse_descriptions(text):
    # Find all descriptions enclosed in brackets
    descriptions = re.findall(r'\[([^\[\]]+)\]', text)
    # Keep only descriptions with at least 3 words
    return [desc.strip() for desc in descriptions if len(desc.split()) >= 3]


# Seed words pool
seed_words = []
used_words = set()

# Queue to store parsed descriptions
parsed_descriptions_queue = deque()

# Usage limits
MAX_DESCRIPTIONS = 30
MAX_IMAGES = 12


@spaces.GPU
def generate_descriptions(user_prompt, seed_words_input, batch_size=100, max_iterations=50):
    iteration_count = 0

    # Initialize the text generation pipeline with 16-bit precision
    print("Initializing the text generation pipeline with 16-bit precision...")
    model_name = 'meta-llama/Meta-Llama-3.1-8B-Instruct'
    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map='auto')
    llm_tokenizer = AutoTokenizer.from_pretrained(model_name)  # named to avoid shadowing the global CLIP tokenizer
    text_generator = pipeline('text-generation', model=model, tokenizer=llm_tokenizer)
    print("Text generation pipeline initialized with 16-bit precision.")

    # Populate the seed-word pool with the quoted words from the user input
    seed_words.extend(re.findall(r'"(.*?)"', seed_words_input))

    while iteration_count < max_iterations and len(parsed_descriptions_queue) < MAX_DESCRIPTIONS:
        # Select a subject that has not been used yet
        available_subjects = [word for word in seed_words if word not in used_words]
        if not available_subjects:
            print("No more available subjects to use.")
            break

        subject = random.choice(available_subjects)
        generated_description = generate_description_prompt(subject, user_prompt, text_generator)

        if generated_description:
            # Drop any non-ASCII symbols
            clean_description = generated_description.encode('ascii', 'ignore').decode('ascii')

            # Print the generated description to the command line
            print(f"Generated description for subject '{subject}': {clean_description}")

            # Mark the subject as used and feed the description back into the seed pool
            used_words.add(subject)
            seed_words.append(clean_description)

            # Parse and enqueue bracketed descriptions every 3 iterations
            if iteration_count % 3 == 0:
                parsed_descriptions_queue.extend(parse_descriptions(clean_description))

        iteration_count += 1

    # NOTE: the LLM stays resident in VRAM when this returns; see the flush()
    # sketch at the end of this file for one way to release it
    return list(parsed_descriptions_queue)


@spaces.GPU(duration=120)
def generate_images(parsed_descriptions):
    # If there are fewer than MAX_IMAGES descriptions, use whatever is available
    if len(parsed_descriptions) < MAX_IMAGES:
        prompts = parsed_descriptions
    else:
        prompts = [parsed_descriptions.pop(0) for _ in range(MAX_IMAGES)]

    # Generate one image per parsed description; schnell is a distilled model,
    # so it runs in ~4 steps with guidance disabled
    images = []
    for prompt in prompts:
        images.extend(pipe(prompt, num_images_per_prompt=1, num_inference_steps=4, guidance_scale=0.0).images)

    return images


# Combined handler wired into the Gradio interface
def combined_function(user_prompt, seed_words_input):
    parsed_descriptions = generate_descriptions(user_prompt, seed_words_input)
    images = generate_images(parsed_descriptions)
    # (a save_images helper, sketched at the end of this file, could persist
    # these to output_dir)
    return images
if __name__ == '__main__':
    # Ensure CUDA is initialized correctly
    torch.cuda.init()

    interface = gr.Interface(
        fn=combined_function,
        inputs=[
            gr.Textbox(lines=2, placeholder="Enter a prompt for descriptions..."),
            gr.Textbox(lines=2, placeholder='Enter seed words in quotes, e.g., "cat", "dog", "sunset"...'),
        ],
        outputs=gr.Gallery(),
    )
    interface.launch()
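

# ---------------------------------------------------------------------------
# Sketched helpers (assumptions, not part of the original app). To be callable
# from the Gradio handlers they would have to be moved above interface.launch().

# output_dir is created above "to save the generated images", but nothing in
# the script writes to it. A minimal sketch of a saver that combined_function
# could call on the images returned by generate_images; the save_images name
# and the filename scheme are assumptions.
def save_images(images, directory=output_dir):
    """Save a list of PIL images as sequentially numbered PNG files."""
    paths = []
    for i, image in enumerate(images):
        path = os.path.join(directory, f'image_{i:03d}.png')
        image.save(path)  # PIL infers the PNG format from the file extension
        paths.append(path)
    return paths


# generate_descriptions leaves the 8B LLM in VRAM right before generate_images
# loads FLUX onto the same GPU. A minimal cleanup sketch, assuming the caller
# first drops its own references (e.g. `del model, text_generator`); the flush
# name is hypothetical.
def flush():
    """Clear Python garbage and cached CUDA memory between pipeline stages."""
    gc.collect()              # reclaim unreferenced Python objects
    torch.cuda.empty_cache()  # release cached CUDA blocks back to the driver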
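

# For long-running jobs like these, Gradio's request queue avoids client
# timeouts. queue() is real Gradio API, but enabling it here is an assumption,
# not part of the original script:
#
#     interface.queue().launch()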