import spaces  # keep at the very beginning of the file (required for ZeroGPU Spaces)
import torch.multiprocessing as mp
import torch
import os
import pandas as pd
import gc
import re
import random
from tqdm.auto import tqdm
from collections import deque
from optimum.quanto import freeze, qfloat8, quantize
from diffusers import FlowMatchEulerDiscreteScheduler, AutoencoderKL
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
import gradio as gr
from accelerate import Accelerator
# Check if the start method has already been set
if mp.get_start_method(allow_none=True) != 'spawn':
    mp.set_start_method('spawn')
# Instantiate the Accelerator
accelerator = Accelerator()
dtype = torch.bfloat16
# Set environment variables for local path
os.environ['FLUX_DEV'] = '.'
os.environ['AE'] = '.'
bfl_repo = 'black-forest-labs/FLUX.1-schnell'
revision = 'refs/pr/1'
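# 'refs/pr/1' pins a pull-request revision of the FLUX.1-schnell Hub repo (presumably the one
# providing the diffusers-format weights loaded below), so later changes to the repo cannot
# break these from_pretrained calls.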
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(bfl_repo, subfolder='scheduler', revision=revision)
text_encoder = CLIPTextModel.from_pretrained('openai/clip-vit-large-patch14', torch_dtype=dtype)
tokenizer = CLIPTokenizer.from_pretrained('openai/clip-vit-large-patch14')
text_encoder_2 = T5EncoderModel.from_pretrained(bfl_repo, subfolder='text_encoder_2', torch_dtype=dtype, revision=revision)
tokenizer_2 = T5TokenizerFast.from_pretrained(bfl_repo, subfolder='tokenizer_2', revision=revision)
vae = AutoencoderKL.from_pretrained(bfl_repo, subfolder='vae', torch_dtype=dtype, revision=revision)
transformer = FluxTransformer2DModel.from_pretrained(bfl_repo, subfolder='transformer', torch_dtype=dtype, revision=revision)
quantize(transformer, weights=qfloat8)
freeze(transformer)
quantize(text_encoder_2, weights=qfloat8)
freeze(text_encoder_2)
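# The FLUX transformer and the T5 text encoder are by far the largest modules, so both are
# quantized to qfloat8 with optimum-quanto and frozen; 8-bit weights take roughly half the
# memory of the bfloat16 originals, which helps the pipeline fit on a single Space GPU.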
pipe = FluxPipeline(
    scheduler=scheduler,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    text_encoder_2=None,
    tokenizer_2=tokenizer_2,
    vae=vae,
    transformer=None,
)
pipe.text_encoder_2 = text_encoder_2
pipe.transformer = transformer
pipe.enable_model_cpu_offload()
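# text_encoder_2 and transformer were passed as None above and are attached here instead,
# presumably so the quanto-quantized modules are used as-is rather than being re-cast during
# pipeline construction. enable_model_cpu_offload() keeps each sub-model on the CPU and moves
# it to the GPU only while its forward pass runs, trading speed for lower peak VRAM.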
# Create a directory to save the generated images
output_dir = 'generated_images'
os.makedirs(output_dir, exist_ok=True)
# Function to generate a detailed visual description prompt
def generate_description_prompt(subject, user_prompt, text_generator):
    prompt = f"write concise vivid visual description enclosed in brackets like [ <description> ] less than 100 words of {user_prompt} different from {subject}. "
    try:
        generated_text = text_generator(prompt, max_length=160, num_return_sequences=1, truncation=True)[0]['generated_text']
        generated_description = re.sub(rf'{re.escape(prompt)}\s*', '', generated_text).strip()  # Remove the prompt from the generated text
        return generated_description if generated_description else None
    except Exception as e:
        print(f"Error generating description for subject '{subject}': {e}")
        return None
# Function to parse descriptions from a given text
def parse_descriptions(text):
    # Find all descriptions enclosed in brackets
    descriptions = re.findall(r'\[([^\[\]]+)\]', text)
    # Keep only descriptions with at least 3 words
    descriptions = [desc.strip() for desc in descriptions if len(desc.split()) >= 3]
    return descriptions
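# Example (hypothetical input) of what parse_descriptions extracts:
#   parse_descriptions('[a misty pine forest at dawn] some noise [ok]')
#   returns ['a misty pine forest at dawn']; '[ok]' is dropped because it has fewer than 3 words.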
# Seed words pool
seed_words = []
used_words = set()
# Queue to store parsed descriptions
parsed_descriptions_queue = deque()
# Usage limits
MAX_DESCRIPTIONS = 30
MAX_IMAGES = 12
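# These caps bound the work done per request: at most 30 parsed descriptions are collected
# and at most 12 of them are rendered into images.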
@spaces.GPU
def generate_descriptions(user_prompt, seed_words_input, batch_size=100, max_iterations=50):
    descriptions = []
    description_queue = deque()
    iteration_count = 0
    # Initialize the text generation pipeline with 16-bit precision
    print("Initializing the text generation pipeline with 16-bit precision...")
    model_name = 'meta-llama/Meta-Llama-3.1-8B-Instruct'
    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map='auto')
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    text_generator = pipeline('text-generation', model=model, tokenizer=tokenizer)
    print("Text generation pipeline initialized with 16-bit precision.")
    # Populate the seed_words pool with the quoted words from the user input
    seed_words.extend(re.findall(r'"(.*?)"', seed_words_input))
    while iteration_count < max_iterations and len(parsed_descriptions_queue) < MAX_DESCRIPTIONS:
        # Select a subject that has not been used yet
        available_subjects = [word for word in seed_words if word not in used_words]
        if not available_subjects:
            print("No more available subjects to use.")
            break
        subject = random.choice(available_subjects)
        generated_description = generate_description_prompt(subject, user_prompt, text_generator)
        if generated_description:
            # Strip non-ASCII symbols that could cause issues downstream
            clean_description = generated_description.encode('ascii', 'ignore').decode('ascii')
            description_queue.append({'subject': subject, 'description': clean_description})
            # Print the generated description to the command line
            print(f"Generated description for subject '{subject}': {clean_description}")
            # Update used words and seed words
            used_words.add(subject)
            seed_words.append(clean_description)  # Add the generated description back into the seed pool
            # Parse and collect bracketed descriptions every 3 iterations
            if iteration_count % 3 == 0:
                parsed_descriptions = parse_descriptions(clean_description)
                parsed_descriptions_queue.extend(parsed_descriptions)
        iteration_count += 1
    return list(parsed_descriptions_queue)
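# Example (hypothetical) call: generate_descriptions('underwater scenes', '"coral reef", "shipwreck"')
# Seed words must be wrapped in double quotes, since they are extracted with re.findall(r'"(.*?)"', ...).
# The image step below asks for a 120-second ZeroGPU slot via @spaces.GPU(duration=120), since a
# batch of FLUX generations typically needs more time than the default allocation.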
@spaces.GPU(duration=120)
def generate_images(parsed_descriptions):
    # If there are fewer than MAX_IMAGES descriptions, use whatever is available
    if len(parsed_descriptions) < MAX_IMAGES:
        prompts = parsed_descriptions
    else:
        prompts = [parsed_descriptions.pop(0) for _ in range(MAX_IMAGES)]
    # Generate one image per parsed description
    images = []
    for prompt in prompts:
        images.extend(pipe(prompt, num_images_per_prompt=1).images)
    return images
# Create Gradio Interface
def combined_function(user_prompt, seed_words_input):
    parsed_descriptions = generate_descriptions(user_prompt, seed_words_input)
    images = generate_images(parsed_descriptions)
    return images
if __name__ == '__main__':
    # Ensure CUDA is initialized correctly
    torch.cuda.init()
    interface = gr.Interface(
        fn=combined_function,
        inputs=[
            gr.Textbox(lines=2, placeholder="Enter a prompt for descriptions..."),
            gr.Textbox(lines=2, placeholder='Enter seed words in quotes, e.g., "cat", "dog", "sunset"...'),
        ],
        outputs=gr.Gallery(),
    )
    interface.launch()