# Hugging Face Space: Llama-3.1 prompt expansion + FLUX.1-schnell image generation demo.
import torch.multiprocessing as mp
import torch
import os
import re
import random
from collections import deque
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
import gradio as gr
from accelerate import Accelerator
import spaces
# Check if the start method has already been set; CUDA in subprocesses
# requires 'spawn' rather than the default 'fork'.
if mp.get_start_method(allow_none=True) != 'spawn':
    mp.set_start_method('spawn')

# Instantiate the Accelerator (handles device placement for the models).
accelerator = Accelerator()

# Working dtype used for every diffusion-pipeline sub-module below.
dtype = torch.bfloat16

# Set environment variables for local path
os.environ['FLUX_DEV'] = '.'
os.environ['AE'] = '.'
os.environ['HF_HUB_ENABLE_HF_TRANSFER'] = 'false'  # Disable HF_HUB_ENABLE_HF_TRANSFER

# Seed words pool, appended to by generate_descriptions; used_words is
# declared but not consumed anywhere in this file — TODO confirm it is needed.
seed_words = []
used_words = set()

# Queue to store parsed descriptions across calls.
parsed_descriptions_queue = deque()

# Usage limits
MAX_DESCRIPTIONS = 30
MAX_IMAGES = 4  # Limit to 4 images

# Preload models and checkpoints at import time so every request reuses them.
print("Preloading models and checkpoints...")
model_name = 'NousResearch/Meta-Llama-3.1-8B-Instruct'
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map='auto')
tokenizer = AutoTokenizer.from_pretrained(model_name)
text_generator = pipeline('text-generation', model=model, tokenizer=tokenizer)
def initialize_diffusers():
    """Build the FLUX.1-schnell pipeline with its two largest modules quantized.

    Loads every sub-module individually, quantizes the transformer and the T5
    text encoder to 8-bit floats, and returns the assembled FluxPipeline with
    CPU offload enabled.
    """
    from optimum.quanto import freeze, qfloat8, quantize
    from diffusers import FlowMatchEulerDiscreteScheduler, AutoencoderKL
    from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
    from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
    from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast

    repo_id = 'black-forest-labs/FLUX.1-schnell'
    repo_revision = 'refs/pr/1'

    # Load each pipeline component separately so the big ones can be quantized
    # before the pipeline is assembled.
    noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(repo_id, subfolder='scheduler', revision=repo_revision)
    clip_encoder = CLIPTextModel.from_pretrained('openai/clip-vit-large-patch14', torch_dtype=dtype)
    clip_tokenizer = CLIPTokenizer.from_pretrained('openai/clip-vit-large-patch14', torch_dtype=dtype)
    t5_encoder = T5EncoderModel.from_pretrained(repo_id, subfolder='text_encoder_2', torch_dtype=dtype, revision=repo_revision)
    t5_tokenizer = T5TokenizerFast.from_pretrained(repo_id, subfolder='tokenizer_2', torch_dtype=dtype, revision=repo_revision)
    flux_vae = AutoencoderKL.from_pretrained(repo_id, subfolder='vae', torch_dtype=dtype, revision=repo_revision)
    flux_transformer = FluxTransformer2DModel.from_pretrained(repo_id, subfolder='transformer', torch_dtype=dtype, revision=repo_revision)

    # Quantize the heavy modules to qfloat8 and freeze their weights.
    for heavy_module in (flux_transformer, t5_encoder):
        quantize(heavy_module, weights=qfloat8)
        freeze(heavy_module)

    # Assemble the pipeline without the quantized modules, then attach them
    # afterwards — they are passed as None at construction time.
    flux_pipe = FluxPipeline(
        scheduler=noise_scheduler,
        text_encoder=clip_encoder,
        tokenizer=clip_tokenizer,
        text_encoder_2=None,
        tokenizer_2=t5_tokenizer,
        vae=flux_vae,
        transformer=None,
    )
    flux_pipe.text_encoder_2 = t5_encoder
    flux_pipe.transformer = flux_transformer
    flux_pipe.enable_model_cpu_offload()
    return flux_pipe
# Build the FLUX pipeline once at import time; all requests share it.
pipe = initialize_diffusers()
print("Models and checkpoints preloaded.")
def generate_description_prompt(user_prompt, text_generator):
    """Ask the LLM for three short bracketed descriptions of *user_prompt*.

    Parameters:
        user_prompt: subject to describe; interpolated into the instruction.
        text_generator: a transformers text-generation pipeline (or compatible
            callable) returning [{'generated_text': str}].

    Returns:
        A list of bracket-enclosed descriptions at least 4 words long, or
        None when generation fails or nothing usable was produced.
    """
    injected_prompt = f"write three concise descriptions enclosed in brackets like [ <description> ] less than 100 words each of {user_prompt}. "
    # FIX: the original used max_length=110, which counts the prompt's own
    # tokens (~45 here), leaving almost no budget for the requested "three
    # descriptions <100 words each". max_new_tokens bounds only the
    # generated continuation.
    max_new_tokens = 110
    try:
        generated_text = text_generator(injected_prompt, max_new_tokens=max_new_tokens, num_return_sequences=1, truncation=True)[0]['generated_text']
        # Extract descriptions enclosed in brackets. The instruction's own
        # "[ <description> ]" is matched too but removed by the word filter.
        generated_descriptions = re.findall(r'\[([^\[\]]+)\]', generated_text)
        # Filter descriptions to ensure they are at least 4 words long.
        filtered_descriptions = [desc for desc in generated_descriptions if len(desc.split()) >= 4]
        return filtered_descriptions if filtered_descriptions else None
    except Exception as e:
        # Best-effort: log and signal failure rather than crash the UI.
        print(f"Error generating descriptions: {e}")
        return None
def format_descriptions(descriptions):
    """Join the descriptions into a single newline-separated string."""
    return "\n".join(descriptions)
@spaces.GPU
def generate_descriptions(user_prompt, seed_words_input, batch_size=100, max_iterations=4):
    """Generate up to MAX_IMAGES descriptions for *user_prompt*.

    Parameters:
        user_prompt: subject passed to the LLM description generator.
        seed_words_input: accepted for interface compatibility; unused here
            (TODO confirm whether it should seed the prompt).
        batch_size: accepted for interface compatibility; unused here.
        max_iterations: number of LLM generation rounds to attempt.

    Returns:
        This call's descriptions, truncated to MAX_IMAGES.
    """
    descriptions = []
    # FIX: the original hard-coded range(4), ignoring max_iterations.
    for _ in range(max_iterations):
        new_descriptions = generate_description_prompt(user_prompt, text_generator)
        if not new_descriptions:
            continue
        descriptions.extend(new_descriptions)
        # Pick a random description to feed back into the seed bank for subject.
        seed_words.append(random.choice(new_descriptions))
        # Stop early once we have enough — further rounds would be discarded.
        if len(descriptions) >= MAX_IMAGES:
            break
    # Limit the number of descriptions to MAX_IMAGES (4).
    descriptions = descriptions[:MAX_IMAGES]
    parsed_descriptions_queue.extend(descriptions)
    # FIX: return this call's descriptions. The original returned
    # list(parsed_descriptions_queue)[:MAX_IMAGES], but the deque is never
    # cleared, so every call after the first returned stale descriptions
    # from the first call.
    return descriptions
@spaces.GPU(duration=120)
def generate_images(parsed_descriptions, max_iterations=4):
    """Render one 1024x1024 image per description via the shared FLUX pipe.

    At most MAX_IMAGES prompts are rendered; per-prompt failures are logged
    and skipped so one bad prompt cannot sink the whole batch.
    """
    # Cap the number of prompts handed to the image generator at MAX_IMAGES.
    selected_prompts = parsed_descriptions[:MAX_IMAGES]
    rendered = []
    for prompt in selected_prompts:
        try:
            # 4 inference steps (schnell variant), fixed 1024 x 1024 output.
            result = pipe(prompt, num_inference_steps=4, height=1024, width=1024)
            rendered.extend(result.images)
        except Exception as e:
            print(f"Error generating image for prompt '{prompt}': {e}")
    return rendered
def combined_function(user_prompt, seed_words_input):
    """Generate descriptions for the prompt, then images for those descriptions.

    Returns a (formatted_descriptions, images) pair for the Gradio outputs.
    """
    descriptions = generate_descriptions(user_prompt, seed_words_input)
    text_block = format_descriptions(descriptions)
    gallery = generate_images(descriptions)
    return text_block, gallery
if __name__ == '__main__':
    def generate_and_display(user_prompt, seed_words_input):
        # Thin Gradio-facing wrapper around combined_function.
        formatted_descriptions, images = combined_function(user_prompt, seed_words_input)
        return formatted_descriptions, images

    # Two text inputs (prompt + seed words) -> description text + image gallery.
    interface = gr.Interface(
        fn=generate_and_display,
        inputs=[gr.Textbox(lines=2, placeholder="Enter a prompt for descriptions..."), gr.Textbox(lines=2, placeholder='Enter example in quotes, e.g., "cat", "dog", "sunset"...')],
        outputs=[gr.Textbox(label="Generated Descriptions"), gr.Gallery(label="Generated Images")],
        live=False,  # Set live to False
        allow_flagging='never'  # Disable flagging
    )
    interface.launch(share=True)