import torch.multiprocessing as mp
import torch
import os
import re
import random
from collections import deque
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
import gradio as gr
from accelerate import Accelerator
import spaces
# Check if the start method has already been set
if mp.get_start_method(allow_none=True) != 'spawn':
    mp.set_start_method('spawn')
# Instantiate the Accelerator
accelerator = Accelerator()
dtype = torch.bfloat16
# Set environment variables for local path
os.environ['FLUX_DEV'] = '.'
os.environ['AE'] = '.'
os.environ['HF_HUB_ENABLE_HF_TRANSFER'] = 'false'  # Disable accelerated hf_transfer downloads
# Seed words pool
seed_words = []
used_words = set()
# Queue to store parsed descriptions
parsed_descriptions_queue = deque()
# Usage limits
MAX_DESCRIPTIONS = 30
MAX_IMAGES = 3 # Limit to 3 images
# Preload models and checkpoints
print("Preloading models and checkpoints...")
model_name = 'NousResearch/Meta-Llama-3.1-8B-Instruct'
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map='auto')
tokenizer = AutoTokenizer.from_pretrained(model_name)
text_generator = pipeline('text-generation', model=model, tokenizer=tokenizer)
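# Illustrative usage (assumed prompt text): text_generator("describe a cat",
# max_length=60, num_return_sequences=1, truncation=True)[0]['generated_text']
# returns the prompt plus the model's continuation as one string, which is why
# generate_description_prompt() below strips the prompt back out.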
def initialize_diffusers():
    from optimum.quanto import freeze, qfloat8, quantize
    from diffusers import FlowMatchEulerDiscreteScheduler, AutoencoderKL
    from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
    from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
    from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast

    bfl_repo = 'black-forest-labs/FLUX.1-schnell'
    revision = 'refs/pr/1'

    scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(bfl_repo, subfolder='scheduler', revision=revision)
    text_encoder = CLIPTextModel.from_pretrained('openai/clip-vit-large-patch14', torch_dtype=dtype)
    tokenizer = CLIPTokenizer.from_pretrained('openai/clip-vit-large-patch14')  # tokenizers take no torch_dtype
    text_encoder_2 = T5EncoderModel.from_pretrained(bfl_repo, subfolder='text_encoder_2', torch_dtype=dtype, revision=revision)
    tokenizer_2 = T5TokenizerFast.from_pretrained(bfl_repo, subfolder='tokenizer_2', revision=revision)
    vae = AutoencoderKL.from_pretrained(bfl_repo, subfolder='vae', torch_dtype=dtype, revision=revision)
    transformer = FluxTransformer2DModel.from_pretrained(bfl_repo, subfolder='transformer', torch_dtype=dtype, revision=revision)

    # Quantize the two largest components to 8-bit floats to cut GPU memory use
    quantize(transformer, weights=qfloat8)
    freeze(transformer)
    quantize(text_encoder_2, weights=qfloat8)
    freeze(text_encoder_2)

    # Build the pipeline without the quantized modules, then attach them afterwards
    # so FluxPipeline does not try to cast the quantized weights during __init__
    pipe = FluxPipeline(
        scheduler=scheduler,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        text_encoder_2=None,
        tokenizer_2=tokenizer_2,
        vae=vae,
        transformer=None,
    )
    pipe.text_encoder_2 = text_encoder_2
    pipe.transformer = transformer
    pipe.enable_model_cpu_offload()  # Keep only the active component on the GPU

    return pipe
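# Design note (rough estimate, not measured here): qfloat8 weights take about half
# the memory of the bf16 transformer and T5 encoder, and enable_model_cpu_offload()
# moves idle components off the GPU between pipeline stages; together these are
# what let a FLUX pipeline fit on a single Space GPU. Exact savings depend on
# hardware and library versions.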
pipe = initialize_diffusers()
print("Models and checkpoints preloaded.")
def generate_description_prompt(subject, user_prompt, text_generator):
prompt = f"write concise vivid visual description enclosed in brackets like [ <description> ] less than 100 words of {user_prompt} different from {subject}. "
try:
generated_text = text_generator(prompt, max_length=160, num_return_sequences=1, truncation=True)[0]['generated_text']
generated_description = re.sub(rf'{re.escape(prompt)}\s*', '', generated_text).strip() # Remove the prompt from the generated text
return generated_description if generated_description else None
except Exception as e:
print(f"Error generating description for subject '{subject}': {e}")
return None
def parse_descriptions(text):
    # Keep only bracketed spans containing at least three words
    descriptions = re.findall(r'\[([^\[\]]+)\]', text)
    descriptions = [desc.strip() for desc in descriptions if len(desc.split()) >= 3]
    return descriptions
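# Illustrative example: parse_descriptions('intro [a misty forest at dawn] and [red]')
# returns ['a misty forest at dawn'] -- the single-word span is filtered out.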
def format_descriptions(descriptions):
    formatted_descriptions = "\n".join(descriptions)
    return formatted_descriptions
@spaces.GPU
def generate_descriptions(user_prompt, seed_words_input, batch_size=100, max_iterations=1):  # batch_size is currently unused
    description_queue = deque()
    iteration_count = 0

    # Seed words are supplied in double quotes, e.g. '"cat", "dog"'
    seed_words.extend(re.findall(r'"(.*?)"', seed_words_input))

    for _ in range(2):  # Outer retry pass; since iteration_count is never reset, with max_iterations=1 the inner loop body runs at most once overall
        while iteration_count < max_iterations and len(parsed_descriptions_queue) < MAX_DESCRIPTIONS:
            available_subjects = [word for word in seed_words if word not in used_words]
            if not available_subjects:
                print("No more available subjects to use.")
                break

            subject = random.choice(available_subjects)
            generated_description = generate_description_prompt(subject, user_prompt, text_generator)

            if generated_description:
                # Strip non-ASCII characters, then feed the description back in as a new seed word
                clean_description = generated_description.encode('ascii', 'ignore').decode('ascii')
                description_queue.append({'subject': subject, 'description': clean_description})
                print(f"Generated description for subject '{subject}': {clean_description}")

                used_words.add(subject)
                seed_words.append(clean_description)

                parsed_descriptions = parse_descriptions(clean_description)
                parsed_descriptions_queue.extend(parsed_descriptions)

            iteration_count += 1

    return list(parsed_descriptions_queue)
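# Note: seed_words, used_words, and parsed_descriptions_queue are module-level,
# so they persist across Gradio calls; repeated runs keep accumulating
# descriptions until the MAX_DESCRIPTIONS cap in the loop condition is reached.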
@spaces.GPU(duration=120)
def generate_images(parsed_descriptions, max_iterations=3):  # max_iterations is the number of inference steps per image
    # Cap the number of descriptions passed to the image generator at MAX_IMAGES
    if len(parsed_descriptions) > MAX_IMAGES:
        parsed_descriptions = parsed_descriptions[:MAX_IMAGES]

    images = []
    for prompt in parsed_descriptions:
        # FLUX.1-schnell is timestep-distilled, so a handful of steps suffices; output is 512 x 512
        images.extend(pipe(prompt, num_inference_steps=max_iterations, height=512, width=512).images)

    return images
def combined_function(user_prompt, seed_words_input):
    parsed_descriptions = generate_descriptions(user_prompt, seed_words_input)
    formatted_descriptions = format_descriptions(parsed_descriptions)
    images = generate_images(parsed_descriptions)
    return formatted_descriptions, images
if __name__ == '__main__':
    def generate_and_display(user_prompt, seed_words_input):
        formatted_descriptions, images = combined_function(user_prompt, seed_words_input)
        return formatted_descriptions, images

    interface = gr.Interface(
        fn=generate_and_display,
        inputs=[gr.Textbox(lines=2, placeholder="Enter a prompt for descriptions..."), gr.Textbox(lines=2, placeholder='Enter seed words in quotes, e.g., "cat", "dog", "sunset"...')],
        outputs=[gr.Textbox(label="Generated Descriptions"), gr.Gallery(label="Generated Images")],
        live=False,
        allow_flagging='never'  # Disable flagging
    )

    interface.launch(share=True)