patrickbdevaney committed
Commit 1716c9d · verified · 1 Parent(s): d5f6733

load model before gpu spaces invoke
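In the previous revision, generate_descriptions() rebuilt the Llama text-generation pipeline on every call and generate_images() called initialize_diffusers() inside the @spaces.GPU-decorated function, so each GPU invocation paid the full checkpoint-loading cost inside its timed window. This commit moves both initializations to module scope, so they run once at startup, before any GPU allocation is requested. A minimal sketch of that pattern, separate from this repo's actual code (model name and values are illustrative):

    import spaces  # Hugging Face ZeroGPU helper
    import torch
    from transformers import pipeline

    # Module scope: executed once at Space startup, before any GPU window opens.
    text_generator = pipeline(
        'text-generation',
        model='NousResearch/Meta-Llama-3.1-8B-Instruct',  # model used by the commit
        torch_dtype=torch.float16,
        device_map='auto',
    )

    @spaces.GPU(duration=120)  # only this call runs inside the timed GPU window
    def generate(prompt: str) -> str:
        # Inference only; no checkpoint loading inside the decorated function.
        return text_generator(prompt, max_new_tokens=100)[0]['generated_text']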

Files changed (1): app.py (+39 -48)
app.py CHANGED
@@ -34,43 +34,43 @@ parsed_descriptions_queue = deque()
 MAX_DESCRIPTIONS = 30
 MAX_IMAGES = 1 # Generate only 1 image
 
-def initialize_diffusers():
-    from optimum.quanto import freeze, qfloat8, quantize
-    from diffusers import FlowMatchEulerDiscreteScheduler, AutoencoderKL
-    from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
-    from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
-    from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
-
-    bfl_repo = 'black-forest-labs/FLUX.1-schnell'
-    revision = 'refs/pr/1'
-
-    scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(bfl_repo, subfolder='scheduler', revision=revision)
-    text_encoder = CLIPTextModel.from_pretrained('openai/clip-vit-large-patch14', torch_dtype=dtype)
-    tokenizer = CLIPTokenizer.from_pretrained('openai/clip-vit-large-patch14', torch_dtype=dtype)
-    text_encoder_2 = T5EncoderModel.from_pretrained(bfl_repo, subfolder='text_encoder_2', torch_dtype=dtype, revision=revision)
-    tokenizer_2 = T5TokenizerFast.from_pretrained(bfl_repo, subfolder='tokenizer_2', torch_dtype=dtype, revision=revision)
-    vae = AutoencoderKL.from_pretrained(bfl_repo, subfolder='vae', torch_dtype=dtype, revision=revision)
-    transformer = FluxTransformer2DModel.from_pretrained(bfl_repo, subfolder='transformer', torch_dtype=dtype, revision=revision)
-
-    quantize(transformer, weights=qfloat8)
-    freeze(transformer)
-    quantize(text_encoder_2, weights=qfloat8)
-    freeze(text_encoder_2)
-
-    pipe = FluxPipeline(
-        scheduler=scheduler,
-        text_encoder=text_encoder,
-        tokenizer=tokenizer,
-        text_encoder_2=None,
-        tokenizer_2=tokenizer_2,
-        vae=vae,
-        transformer=None,
-    )
-    pipe.text_encoder_2 = text_encoder_2
-    pipe.transformer = transformer
-    pipe.enable_model_cpu_offload()
-
-    return pipe
+# Preload models and checkpoints
+print("Preloading models and checkpoints...")
+model_name = 'NousResearch/Meta-Llama-3.1-8B-Instruct'
+model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map='auto')
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+text_generator = pipeline('text-generation', model=model, tokenizer=tokenizer)
+
+bfl_repo = 'black-forest-labs/FLUX.1-schnell'
+revision = 'refs/pr/1'
+
+scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(bfl_repo, subfolder='scheduler', revision=revision)
+text_encoder = CLIPTextModel.from_pretrained('openai/clip-vit-large-patch14', torch_dtype=dtype)
+tokenizer_clip = CLIPTokenizer.from_pretrained('openai/clip-vit-large-patch14', torch_dtype=dtype)
+text_encoder_2 = T5EncoderModel.from_pretrained(bfl_repo, subfolder='text_encoder_2', torch_dtype=dtype, revision=revision)
+tokenizer_2 = T5TokenizerFast.from_pretrained(bfl_repo, subfolder='tokenizer_2', torch_dtype=dtype, revision=revision)
+vae = AutoencoderKL.from_pretrained(bfl_repo, subfolder='vae', torch_dtype=dtype, revision=revision)
+transformer = FluxTransformer2DModel.from_pretrained(bfl_repo, subfolder='transformer', torch_dtype=dtype, revision=revision)
+
+quantize(transformer, weights=qfloat8)
+freeze(transformer)
+quantize(text_encoder_2, weights=qfloat8)
+freeze(text_encoder_2)
+
+pipe = FluxPipeline(
+    scheduler=scheduler,
+    text_encoder=text_encoder,
+    tokenizer=tokenizer_clip,
+    text_encoder_2=None,
+    tokenizer_2=tokenizer_2,
+    vae=vae,
+    transformer=None,
+)
+pipe.text_encoder_2 = text_encoder_2
+pipe.transformer = transformer
+pipe.enable_model_cpu_offload()
+
+print("Models and checkpoints preloaded.")
 
 def generate_description_prompt(subject, user_prompt, text_generator):
     prompt = f"write concise vivid visual description enclosed in brackets like [ <description> ] less than 100 words of {user_prompt} different from {subject}. "
@@ -93,13 +93,6 @@ def generate_descriptions(user_prompt, seed_words_input, batch_size=100, max_ite
     description_queue = deque()
     iteration_count = 0
 
-    print("Initializing the text generation pipeline with 16-bit precision...")
-    model_name = 'NousResearch/Meta-Llama-3.1-8B-Instruct'
-    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map='auto')
-    tokenizer = AutoTokenizer.from_pretrained(model_name)
-    text_generator = pipeline('text-generation', model=model, tokenizer=tokenizer)
-    print("Text generation pipeline initialized with 16-bit precision.")
-
     seed_words.extend(re.findall(r'"(.*?)"', seed_words_input))
 
     for _ in range(2): # Perform two iterations
@@ -128,10 +121,8 @@ def generate_descriptions(user_prompt, seed_words_input, batch_size=100, max_ite
     return list(parsed_descriptions_queue)
 
-@spaces.GPU
+@spaces.GPU(duration=120)
 def generate_images(parsed_descriptions, max_iterations=2): # Set max_iterations to 1
-    pipe = initialize_diffusers()
-
     if len(parsed_descriptions) < MAX_IMAGES:
         prompts = parsed_descriptions
     else:
@@ -161,4 +152,4 @@ if __name__ == '__main__':
         allow_flagging='never' # Disable flagging
     )
 
-    interface.launch(share=True)
+    interface.launch(share=True)
 
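Two details of the new version are worth noting: @spaces.GPU(duration=120) requests a longer per-call GPU window than the default, which gives the offloaded, qfloat8-quantized FLUX components time to move on and off the device, and enable_model_cpu_offload() lets the preloaded pipeline live in CPU RAM, moving each component to the GPU only while it is needed. With the pipeline built at module scope, the decorated function only has to run inference. A hedged sketch of such a call against the module-level pipe (the sampling parameters below are assumptions typical for FLUX.1-schnell, not values taken from this commit):

    import torch

    # Hypothetical inference helper using the module-level `pipe` built above.
    # FLUX.1-schnell is a few-step distilled model, so guidance_scale=0.0 and a
    # small num_inference_steps are customary; these exact values are assumptions.
    def render(prompt: str):
        generator = torch.Generator('cpu').manual_seed(0)  # reproducible sampling
        image = pipe(
            prompt,
            guidance_scale=0.0,
            num_inference_steps=4,
            max_sequence_length=256,
            generator=generator,
        ).images[0]
        image.save('output.png')
        return image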