import torch.multiprocessing as mp
import torch
import os
import re
import random
from collections import deque
from transformers import (
    pipeline, AutoModelForCausalLM, AutoTokenizer,
    CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast,
)
from diffusers import (
    FluxPipeline, FluxTransformer2DModel, AutoencoderKL,
    FlowMatchEulerDiscreteScheduler,
)
from optimum.quanto import quantize, qfloat8, freeze
import gradio as gr
from accelerate import Accelerator
import spaces
# Check if the start method has already been set
if mp.get_start_method(allow_none=True) != 'spawn':
    mp.set_start_method('spawn')
# Instantiate the Accelerator
accelerator = Accelerator()
dtype = torch.bfloat16
# Set environment variables for local path
os.environ['FLUX_DEV'] = '.'
os.environ['AE'] = '.'
# Seed words pool
seed_words = []
used_words = set()
# Queue to store parsed descriptions
parsed_descriptions_queue = deque()
# Usage limits
MAX_DESCRIPTIONS = 30
MAX_IMAGES = 1 # Generate only 1 image
# Preload models and checkpoints
print("Preloading models and checkpoints...")
model_name = 'NousResearch/Meta-Llama-3.1-8B-Instruct'
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map='auto')
tokenizer = AutoTokenizer.from_pretrained(model_name)
text_generator = pipeline('text-generation', model=model, tokenizer=tokenizer)
bfl_repo = 'black-forest-labs/FLUX.1-schnell'
revision = 'refs/pr/1'
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(bfl_repo, subfolder='scheduler', revision=revision)
text_encoder = CLIPTextModel.from_pretrained('openai/clip-vit-large-patch14', torch_dtype=dtype)
tokenizer_clip = CLIPTokenizer.from_pretrained('openai/clip-vit-large-patch14', torch_dtype=dtype)
text_encoder_2 = T5EncoderModel.from_pretrained(bfl_repo, subfolder='text_encoder_2', torch_dtype=dtype, revision=revision)
tokenizer_2 = T5TokenizerFast.from_pretrained(bfl_repo, subfolder='tokenizer_2', torch_dtype=dtype, revision=revision)
vae = AutoencoderKL.from_pretrained(bfl_repo, subfolder='vae', torch_dtype=dtype, revision=revision)
transformer = FluxTransformer2DModel.from_pretrained(bfl_repo, subfolder='transformer', torch_dtype=dtype, revision=revision)
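# Quantize the two largest components (the T5 text encoder and the Flux transformer)
# to 8-bit floats with optimum.quanto, then freeze them so the quantized weights are
# made permanent; this substantially shrinks their memory footprint.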
quantize(transformer, weights=qfloat8)
freeze(transformer)
quantize(text_encoder_2, weights=qfloat8)
freeze(text_encoder_2)
pipe = FluxPipeline(
    scheduler=scheduler,
    text_encoder=text_encoder,
    tokenizer=tokenizer_clip,
    text_encoder_2=None,
    tokenizer_2=tokenizer_2,
    vae=vae,
    transformer=None,
)
pipe.text_encoder_2 = text_encoder_2
pipe.transformer = transformer
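# Offload idle pipeline components to the CPU; diffusers moves each module onto the GPU
# only while it is needed, trading some speed for a much smaller peak VRAM footprint.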
pipe.enable_model_cpu_offload()
print("Models and checkpoints preloaded.")
def generate_description_prompt(subject, user_prompt, text_generator):
    prompt = f"write concise vivid visual description enclosed in brackets like [ <description> ] less than 100 words of {user_prompt} different from {subject}. "
    try:
        generated_text = text_generator(prompt, max_length=160, num_return_sequences=1, truncation=True)[0]['generated_text']
        generated_description = re.sub(rf'{re.escape(prompt)}\s*', '', generated_text).strip()  # Remove the prompt from the generated text
        return generated_description if generated_description else None
    except Exception as e:
        print(f"Error generating description for subject '{subject}': {e}")
        return None
def parse_descriptions(text):
    descriptions = re.findall(r'\[([^\[\]]+)\]', text)
    descriptions = [desc.strip() for desc in descriptions if len(desc.split()) >= 3]
    return descriptions
@spaces.GPU
def generate_descriptions(user_prompt, seed_words_input, batch_size=100, max_iterations=1):  # Set max_iterations to 1
    descriptions = []
    description_queue = deque()
    iteration_count = 0
    seed_words.extend(re.findall(r'"(.*?)"', seed_words_input))
    for _ in range(2):  # Perform two iterations
        while iteration_count < max_iterations and len(parsed_descriptions_queue) < MAX_DESCRIPTIONS:
            available_subjects = [word for word in seed_words if word not in used_words]
            if not available_subjects:
                print("No more available subjects to use.")
                break
            subject = random.choice(available_subjects)
            generated_description = generate_description_prompt(subject, user_prompt, text_generator)
            if generated_description:
                clean_description = generated_description.encode('ascii', 'ignore').decode('ascii')
                description_queue.append({'subject': subject, 'description': clean_description})
                print(f"Generated description for subject '{subject}': {clean_description}")
                used_words.add(subject)
                seed_words.append(clean_description)
                parsed_descriptions = parse_descriptions(clean_description)
                parsed_descriptions_queue.extend(parsed_descriptions)
            iteration_count += 1
    return list(parsed_descriptions_queue)
@spaces.GPU(duration=120)
def generate_images(parsed_descriptions, max_iterations=2):  # max_iterations is reused as the number of inference steps
    if len(parsed_descriptions) < MAX_IMAGES:
        prompts = parsed_descriptions
    else:
        prompts = [parsed_descriptions.pop(0) for _ in range(MAX_IMAGES)]
    images = []
    for prompt in prompts:
        images.extend(pipe(prompt, num_images_per_prompt=1, num_inference_steps=max_iterations, height=512, width=512).images)  # 512 x 512 output
    return images
def combined_function(user_prompt, seed_words_input):
    parsed_descriptions = generate_descriptions(user_prompt, seed_words_input)
    images = generate_images(parsed_descriptions)
    return parsed_descriptions, images
if __name__ == '__main__':
    def generate_and_display(user_prompt, seed_words_input):
        parsed_descriptions, images = combined_function(user_prompt, seed_words_input)
        return parsed_descriptions, images

    interface = gr.Interface(
        fn=generate_and_display,
        inputs=[
            gr.Textbox(lines=2, placeholder="Enter a prompt for descriptions..."),
            gr.Textbox(lines=2, placeholder='Enter seed words in quotes, e.g., "cat", "dog", "sunset"...'),
        ],
        outputs=[gr.Textbox(label="Generated Descriptions"), gr.Gallery(label="Generated Images")],
        live=False,  # Set live to False
        allow_flagging='never'  # Disable flagging
    )

    interface.launch(share=True)
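# Example (untested sketch): exercising the pipeline without the Gradio UI, e.g. from a
# notebook, assuming the models above have finished loading. The prompt, seed words, and
# output file name below are arbitrary placeholders.
#   descriptions, images = combined_function("a misty harbor at dawn", '"lighthouse", "fishing boat"')
#   images[0].save("output.png")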