Spaces:
Sleeping
Sleeping
import torch.multiprocessing as mp | |
import torch | |
import os | |
import re | |
import random | |
from collections import deque | |
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer | |
import gradio as gr | |
from accelerate import Accelerator | |
import spaces | |
# Make sure worker processes are started with 'spawn' (presumably because the
# models use CUDA, which does not survive a fork -- confirm).
# `force=True` allows the switch even when another module has already fixed a
# different start method; without it set_start_method() raises RuntimeError
# in that situation.
if mp.get_start_method(allow_none=True) != 'spawn':
    mp.set_start_method('spawn', force=True)
# Shared Accelerator instance and the dtype handed to the diffusion components.
accelerator = Accelerator()
dtype = torch.bfloat16

# Local-path settings plus the switch that turns hf_transfer downloads off.
# NOTE(review): these are set after `transformers` has already been imported;
# confirm the hub client still honours HF_HUB_ENABLE_HF_TRANSFER this late.
os.environ.update({
    'FLUX_DEV': '.',
    'AE': '.',
    'HF_HUB_ENABLE_HF_TRANSFER': 'false',
})

# Seed words pool and the companion set of used words.
seed_words = []
used_words = set()

# Queue that stores parsed descriptions.
parsed_descriptions_queue = deque()

# Usage limits.
MAX_DESCRIPTIONS = 30
MAX_IMAGES = 4  # never hand more than four descriptions to the image model
# Load the description-writing language model once, when the module is loaded,
# and wrap it in a text-generation pipeline shared by all requests.
print("Preloading models and checkpoints...")
model_name = 'NousResearch/Meta-Llama-3.1-8B-Instruct'
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map='auto',
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
text_generator = pipeline(
    'text-generation',
    model=model,
    tokenizer=tokenizer,
)
def initialize_diffusers():
    """Build the FLUX.1-schnell image pipeline with qfloat8-quantized weights.

    Loads the scheduler, the CLIP and T5 text encoders with their tokenizers,
    the VAE and the transformer, quantizes the transformer and the T5 encoder
    to qfloat8 and freezes them, then assembles a ``FluxPipeline`` and enables
    model CPU offload on it.

    Returns:
        The assembled ``FluxPipeline``.
    """
    from optimum.quanto import freeze, qfloat8, quantize
    from diffusers import FlowMatchEulerDiscreteScheduler, AutoencoderKL
    from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
    from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
    from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast

    flux_repo = 'black-forest-labs/FLUX.1-schnell'
    flux_revision = 'refs/pr/1'
    clip_repo = 'openai/clip-vit-large-patch14'
    # Keyword arguments shared by every dtype-aware FLUX component load.
    flux_kwargs = {'torch_dtype': dtype, 'revision': flux_revision}

    flow_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
        flux_repo, subfolder='scheduler', revision=flux_revision
    )
    clip_encoder = CLIPTextModel.from_pretrained(clip_repo, torch_dtype=dtype)
    clip_tokenizer = CLIPTokenizer.from_pretrained(clip_repo, torch_dtype=dtype)
    t5_encoder = T5EncoderModel.from_pretrained(
        flux_repo, subfolder='text_encoder_2', **flux_kwargs
    )
    t5_tokenizer = T5TokenizerFast.from_pretrained(
        flux_repo, subfolder='tokenizer_2', **flux_kwargs
    )
    autoencoder = AutoencoderKL.from_pretrained(
        flux_repo, subfolder='vae', **flux_kwargs
    )
    flux_transformer = FluxTransformer2DModel.from_pretrained(
        flux_repo, subfolder='transformer', **flux_kwargs
    )

    # Quantize, then freeze, the two components that get qfloat8 weights.
    for module in (flux_transformer, t5_encoder):
        quantize(module, weights=qfloat8)
        freeze(module)

    # The quantized modules are passed as None to the constructor and attached
    # afterwards, before CPU offload is enabled; keep this order.
    flux_pipeline = FluxPipeline(
        scheduler=flow_scheduler,
        text_encoder=clip_encoder,
        tokenizer=clip_tokenizer,
        text_encoder_2=None,
        tokenizer_2=t5_tokenizer,
        vae=autoencoder,
        transformer=None,
    )
    flux_pipeline.text_encoder_2 = t5_encoder
    flux_pipeline.transformer = flux_transformer
    flux_pipeline.enable_model_cpu_offload()
    return flux_pipeline
# Build the image pipeline once when the module is loaded; generate_images()
# reads this module-level `pipe`.
pipe = initialize_diffusers()
print("Models and checkpoints preloaded.")
def generate_description_prompt(user_prompt, text_generator):
    """Ask the language model for bracketed descriptions of ``user_prompt``.

    Args:
        user_prompt: Subject the descriptions should be about.
        text_generator: Callable with the text-generation pipeline interface.
            It is called with the prompt plus generation keyword arguments and
            must return a list whose first item is a dict with a
            ``'generated_text'`` key.

    Returns:
        A list of descriptions (surrounding whitespace removed, at least four
        words each), or ``None`` when no usable description was produced or
        generation failed. Failures are printed, not raised.
    """
    injected_prompt = f"write three concise descriptions enclosed in brackets like [ <description> ] less than 100 words each of {user_prompt}. "
    # Budget for *generated* tokens. A total-length limit would also count the
    # prompt, so a long user prompt could leave no room for any output.
    max_new_tokens = 110
    try:
        generated_text = text_generator(
            injected_prompt,
            max_new_tokens=max_new_tokens,
            num_return_sequences=1,
            truncation=True,
        )[0]['generated_text']
        # The generator may echo the prompt; drop it so bracketed text typed by
        # the user (or the "[ <description> ]" placeholder) is never parsed as
        # a generated description.
        if generated_text.startswith(injected_prompt):
            generated_text = generated_text[len(injected_prompt):]
        # Extract the bracketed spans and trim the padding inside the brackets.
        candidates = [
            match.strip()
            for match in re.findall(r'\[([^\[\]]+)\]', generated_text)
        ]
        # Keep only descriptions that are at least four words long.
        filtered_descriptions = [
            desc for desc in candidates if len(desc.split()) >= 4
        ]
        return filtered_descriptions or None
    except Exception as e:
        # Deliberate best effort: callers treat None as "nothing generated".
        print(f"Error generating descriptions: {e}")
        return None
def format_descriptions(descriptions):
    """Join descriptions into a single string, one description per line.

    Args:
        descriptions: Iterable of description strings, or ``None``.

    Returns:
        The descriptions separated by newlines; an empty string when
        ``descriptions`` is ``None`` or empty.
    """
    # None is what the description generator returns for "no result"; treat it
    # like an empty list instead of failing inside str.join().
    return "\n".join(descriptions or [])
def generate_descriptions(user_prompt, seed_words_input, batch_size=100, max_iterations=4):
    """Generate up to ``MAX_IMAGES`` descriptions for ``user_prompt``.

    Calls ``generate_description_prompt`` ``max_iterations`` times with the
    module-level ``text_generator`` and collects what it returns.

    Side effects:
        * one randomly chosen description per successful iteration is appended
          to the module-level ``seed_words`` list;
        * the returned descriptions are appended to
          ``parsed_descriptions_queue``, which is trimmed from the left so it
          holds at most ``MAX_DESCRIPTIONS`` entries.

    Args:
        user_prompt: Subject to describe.
        seed_words_input: Accepted for interface compatibility; not used here.
        batch_size: Accepted for interface compatibility; not used here.
        max_iterations: Number of generation rounds to run.

    Returns:
        A list with at most ``MAX_IMAGES`` descriptions produced by this call
        (empty when every round failed).
    """
    descriptions = []
    for _ in range(max_iterations):
        new_descriptions = generate_description_prompt(user_prompt, text_generator)
        if not new_descriptions:
            continue
        descriptions.extend(new_descriptions)
        # Pick a random description to feed back into the seed bank.
        seed_words.append(random.choice(new_descriptions))

    # Limit the number of descriptions to MAX_IMAGES.
    descriptions = descriptions[:MAX_IMAGES]

    # Record the results, keeping the shared queue bounded.
    parsed_descriptions_queue.extend(descriptions)
    while len(parsed_descriptions_queue) > MAX_DESCRIPTIONS:
        parsed_descriptions_queue.popleft()

    # Return this call's descriptions. Reading them back from the head of the
    # shared queue would keep returning the very first request's results.
    return descriptions
def generate_images(parsed_descriptions, max_iterations=4):
    """Render one 1024 x 1024 image per description with the module-level ``pipe``.

    Args:
        parsed_descriptions: Sequence of text prompts; only the first
            ``MAX_IMAGES`` are used.
        max_iterations: Accepted for interface compatibility; not used here.

    Returns:
        A list with the images of every prompt that rendered successfully.
        Prompts that raise are reported with ``print`` and skipped.
    """
    # Limit the prompts passed to the image generator to MAX_IMAGES.
    too_many = len(parsed_descriptions) > MAX_IMAGES
    selected = parsed_descriptions[:MAX_IMAGES] if too_many else parsed_descriptions

    images = []
    for prompt in selected:
        try:
            # Four inference steps at a resolution of 1024 x 1024.
            result = pipe(prompt, num_inference_steps=4, height=1024, width=1024)
            images.extend(result.images)
        except Exception as e:
            print(f"Error generating image for prompt '{prompt}': {e}")
    return images
def combined_function(user_prompt, seed_words_input):
    """Generate descriptions for a prompt and an image for each of them.

    Args:
        user_prompt: Subject to describe.
        seed_words_input: Forwarded unchanged to ``generate_descriptions``.

    Returns:
        A ``(formatted_descriptions, images)`` tuple: the descriptions joined
        into one string, and the list of generated images.
    """
    descriptions = generate_descriptions(user_prompt, seed_words_input)
    # Tuple items are evaluated left to right: text is formatted first, then
    # the images are rendered.
    return format_descriptions(descriptions), generate_images(descriptions)
if __name__ == '__main__':
    def generate_and_display(user_prompt, seed_words_input):
        """Gradio callback: return the description text and the image list."""
        # combined_function runs the same describe -> format -> render steps.
        return combined_function(user_prompt, seed_words_input)

    # Two text inputs (prompt, examples); text and gallery outputs.
    interface = gr.Interface(
        fn=generate_and_display,
        inputs=[
            gr.Textbox(lines=2, placeholder="Enter a prompt for descriptions..."),
            gr.Textbox(lines=2, placeholder='Enter example in quotes, e.g., "cat", "dog", "sunset"...'),
        ],
        outputs=[
            gr.Textbox(label="Generated Descriptions"),
            gr.Gallery(label="Generated Images"),
        ],
        live=False,  # run only on submit
        allow_flagging='never',  # flagging disabled
    )
    interface.launch(share=True)