# Non-code text captured along with this source (kept for reference): "Spaces: / Sleeping / Sleeping"
import spaces # beginn | |
import torch.multiprocessing as mp | |
import torch | |
import os | |
import pandas as pd | |
import gc | |
import re | |
import random | |
from tqdm.auto import tqdm | |
from collections import deque | |
from optimum.quanto import freeze, qfloat8, quantize | |
from diffusers import FlowMatchEulerDiscreteScheduler, AutoencoderKL | |
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel | |
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline | |
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast | |
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer | |
import gradio as gr | |
from accelerate import Accelerator | |
# Use the 'spawn' start method for torch.multiprocessing: per the PyTorch docs,
# CUDA in subprocesses requires 'spawn' (or 'forkserver'), not 'fork'.
# force=True is needed because set_start_method raises RuntimeError when a
# different start method has already been fixed (e.g. by an imported library).
if mp.get_start_method(allow_none=True) != 'spawn':
    mp.set_start_method('spawn', force=True)
# Numeric precision used when loading the FLUX components below.
dtype = torch.bfloat16
# Hugging Face Accelerate handle. It is created here but not referenced again
# in the code shown.
accelerator = Accelerator()
# Environment variables pointing FLUX paths at the working directory.
# NOTE(review): nothing visible here reads FLUX_DEV or AE — confirm they are still needed.
os.environ['FLUX_DEV'] = '.'
os.environ['AE'] = '.'

# Source repositories/revision for the FLUX.1-schnell components and the CLIP encoder.
bfl_repo = 'black-forest-labs/FLUX.1-schnell'
revision = 'refs/pr/1'
clip_repo = 'openai/clip-vit-large-patch14'

scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(bfl_repo, subfolder='scheduler', revision=revision)
text_encoder = CLIPTextModel.from_pretrained(clip_repo, torch_dtype=dtype)
# Tokenizers hold no tensors, so a torch_dtype argument is meaningless for them
# and is not passed.
tokenizer = CLIPTokenizer.from_pretrained(clip_repo)
text_encoder_2 = T5EncoderModel.from_pretrained(bfl_repo, subfolder='text_encoder_2', torch_dtype=dtype, revision=revision)
tokenizer_2 = T5TokenizerFast.from_pretrained(bfl_repo, subfolder='tokenizer_2', revision=revision)
vae = AutoencoderKL.from_pretrained(bfl_repo, subfolder='vae', torch_dtype=dtype, revision=revision)
transformer = FluxTransformer2DModel.from_pretrained(bfl_repo, subfolder='transformer', torch_dtype=dtype, revision=revision)

# Quantize the transformer and the T5 text encoder weights to float8 with
# optimum-quanto; freeze() replaces the float weights with the quantized ones.
quantize(transformer, weights=qfloat8)
freeze(transformer)
quantize(text_encoder_2, weights=qfloat8)
freeze(text_encoder_2)

# Build the pipeline without the quantized modules, then attach them afterwards.
# NOTE(review): presumably follows the diffusers + quanto FLUX fp8 recipe; confirm
# whether passing them to the constructor directly works with the pinned versions.
pipe = FluxPipeline(
    scheduler=scheduler,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    text_encoder_2=None,
    tokenizer_2=tokenizer_2,
    vae=vae,
    transformer=None,
)
pipe.text_encoder_2 = text_encoder_2
pipe.transformer = transformer
# Keep sub-models on the CPU and move each to the accelerator only while it runs.
pipe.enable_model_cpu_offload()
# Folder reserved for generated images. It is created once at import time;
# exist_ok makes re-runs harmless. Nothing in the code shown writes to it yet.
output_dir = 'generated_images'
os.makedirs(name=output_dir, mode=0o777, exist_ok=True)
# Ask the text generator for one bracketed visual description.
def generate_description_prompt(subject, user_prompt, text_generator):
    """Generate a short visual description of ``user_prompt`` that differs from ``subject``.

    The instruction asks the model to wrap its answer in ``[ ... ]``.

    Args:
        subject: Subject the new description should differ from.
        user_prompt: Theme to describe.
        text_generator: Callable with the Hugging Face text-generation pipeline
            interface. It is called as ``text_generator(prompt, max_length=160,
            num_return_sequences=1, truncation=True)`` and must return a list
            whose first item has a ``'generated_text'`` key.

    Returns:
        The generated text with the echoed instruction (and following
        whitespace) removed and surrounding whitespace stripped. Returns
        ``None`` when that text is empty or when any exception is raised; the
        error is printed.
    """
    instruction = f"write concise vivid visual description enclosed in brackets like [ <description> ] less than 100 words of {user_prompt} different from {subject}. "
    try:
        outputs = text_generator(instruction, max_length=160, num_return_sequences=1, truncation=True)
        full_text = outputs[0]['generated_text']
        # By default the text-generation pipeline returns the prompt followed by
        # the continuation, so cut the instruction back out.
        echo_pattern = rf'{re.escape(instruction)}\s*'
        description = re.sub(echo_pattern, '', full_text).strip()
    except Exception as e:
        print(f"Error generating description for subject '{subject}': {e}")
        return None
    return description or None
# Extract bracketed descriptions from model output.
def parse_descriptions(text):
    """Return the contents of every innermost ``[...]`` group in ``text``.

    A group is kept only if it contains at least three whitespace-separated
    words. Kept groups are stripped and returned in order of appearance.
    """
    kept = []
    for inner in re.findall(r'\[([^\[\]]+)\]', text):
        if len(inner.split()) < 3:
            continue  # too short to be a usable description
        kept.append(inner.strip())
    return kept
# Module-level generation state: the pool of candidate subjects, the subjects
# already used, and the queue of parsed descriptions.
seed_words, used_words = [], set()
parsed_descriptions_queue = deque()

# Usage limits:
#   MAX_DESCRIPTIONS - stop collecting parsed descriptions once this many exist.
#   MAX_IMAGES       - maximum number of images rendered from those descriptions.
MAX_DESCRIPTIONS, MAX_IMAGES = 30, 12
def _load_text_generator(model_name):
    """Load ``model_name`` in float16 and wrap it in a text-generation pipeline."""
    print("Initializing the text generation pipeline with 16-bit precision...")
    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map='auto')
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    text_generator = pipeline('text-generation', model=model, tokenizer=tokenizer)
    print("Text generation pipeline initialized with 16-bit precision.")
    return text_generator


def _collect_descriptions(user_prompt, subjects, text_generator, max_iterations):
    """Run the subject -> description chain and return every parsed description.

    ``subjects`` is extended in place with each generated description, so later
    iterations can use it as a subject.
    """
    used_subjects = set()
    collected = []
    for _ in range(max_iterations):
        if len(collected) >= MAX_DESCRIPTIONS:
            break
        # Select a subject that has not been used yet.
        available_subjects = [word for word in subjects if word not in used_subjects]
        if not available_subjects:
            print("No more available subjects to use.")
            break
        subject = random.choice(available_subjects)
        generated_description = generate_description_prompt(subject, user_prompt, text_generator)
        if not generated_description:
            # Failed attempt: the subject stays available, but the attempt
            # still counts toward max_iterations.
            continue
        # Remove any offending (non-ASCII) symbols.
        clean_description = generated_description.encode('ascii', 'ignore').decode('ascii')
        print(f"Generated description for subject '{subject}': {clean_description}")
        used_subjects.add(subject)
        # The description itself becomes a future subject, so the chain keeps evolving.
        subjects.append(clean_description)
        # Parse every description so no generated text is discarded.
        collected.extend(parse_descriptions(clean_description))
    return collected


def generate_descriptions(user_prompt, seed_words_input, batch_size=100, max_iterations=50):
    """Generate bracketed visual descriptions to feed the image pipeline.

    Starting from the double-quoted seed words in ``seed_words_input`` (e.g.
    ``"cat", "dog"``), the function:

    - repeatedly picks an unused subject at random;
    - asks the LLM for a description of ``user_prompt`` that differs from
      that subject;
    - adds each generated description back into the subject pool.

    Bracketed fragments of at least three words are collected until
    MAX_DESCRIPTIONS is reached, the subject pool is exhausted, or
    ``max_iterations`` attempts have been made.

    Args:
        user_prompt: Theme the descriptions should depict.
        seed_words_input: Free text. Only double-quoted substrings are used as
            seeds; ``None`` is treated as empty.
        batch_size: Unused; kept for backward compatibility.
        max_iterations: Upper bound on generation attempts.

    Returns:
        List of parsed description strings. It may be empty, and it may
        slightly exceed MAX_DESCRIPTIONS because one generation can yield
        several fragments.
    """
    # State is local to each call. Module-level pools would carry subjects,
    # used words and results into the next request, which would then return
    # stale output.
    subjects = re.findall(r'"(.*?)"', seed_words_input or '')
    if not subjects:
        # Nothing to expand: skip loading the 8B model entirely.
        print("No more available subjects to use.")
        return []

    text_generator = _load_text_generator('meta-llama/Meta-Llama-3.1-8B-Instruct')
    try:
        return _collect_descriptions(user_prompt, subjects, text_generator, max_iterations)
    finally:
        # Release the LLM before the FLUX pipeline runs so its memory can be reused.
        del text_generator
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
def generate_images(parsed_descriptions, max_images=None):
    """Render one FLUX image per description, using at most ``max_images`` descriptions.

    Args:
        parsed_descriptions: Iterable of text prompts, used in order. It is not
            modified.
        max_images: Upper bound on the number of images. Defaults to
            MAX_IMAGES.

    Returns:
        List of images taken from ``pipe(...).images``, one per prompt used.
    """
    limit = MAX_IMAGES if max_images is None else max_images
    # Take the leading prompts by slicing a copy rather than pop(0). This
    # avoids mutating the caller's list and avoids O(n) pops.
    prompts = list(parsed_descriptions)[:limit]
    images = []
    for prompt in prompts:
        # FluxPipeline.__call__ has no **kwargs and names this parameter
        # num_images_per_prompt; passing num_images raises TypeError.
        images.extend(pipe(prompt, num_images_per_prompt=1).images)
    return images
# Gradio callback: text prompt + quoted seed words -> gallery images.
def combined_function(user_prompt, seed_words_input):
    """Generate descriptions from the two inputs, then render them as images.

    Args:
        user_prompt: Theme for the descriptions (first Gradio textbox).
        seed_words_input: Quoted seed words (second Gradio textbox).

    Returns:
        The list produced by ``generate_images``, displayed in the Gradio gallery.
    """
    # NOTE(review): `spaces` is imported at the top of the file, but nothing is
    # decorated with `@spaces.GPU`. If this runs on ZeroGPU hardware, this
    # callback presumably needs that decorator — confirm.
    return generate_images(generate_descriptions(user_prompt, seed_words_input))
if __name__ == '__main__':
    # PyTorch initializes CUDA lazily on first use, so this explicit call is
    # optional. torch.cuda.init() raises when no CUDA device/driver is
    # present, so only call it when CUDA is actually available.
    if torch.cuda.is_available():
        torch.cuda.init()
    interface = gr.Interface(
        fn=combined_function,
        inputs=[
            gr.Textbox(lines=2, placeholder="Enter a prompt for descriptions..."),
            gr.Textbox(lines=2, placeholder='Enter seed words in quotes, e.g., "cat", "dog", "sunset"...'),
        ],
        outputs=gr.Gallery(),
    )
    interface.launch()