patrickbdevaney commited on
Commit
87f446f
·
verified ·
1 Parent(s): 1716c9d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -30
app.py CHANGED
@@ -41,35 +41,45 @@ model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float
41
  tokenizer = AutoTokenizer.from_pretrained(model_name)
42
  text_generator = pipeline('text-generation', model=model, tokenizer=tokenizer)
43
 
44
- bfl_repo = 'black-forest-labs/FLUX.1-schnell'
45
- revision = 'refs/pr/1'
46
-
47
- scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(bfl_repo, subfolder='scheduler', revision=revision)
48
- text_encoder = CLIPTextModel.from_pretrained('openai/clip-vit-large-patch14', torch_dtype=dtype)
49
- tokenizer_clip = CLIPTokenizer.from_pretrained('openai/clip-vit-large-patch14', torch_dtype=dtype)
50
- text_encoder_2 = T5EncoderModel.from_pretrained(bfl_repo, subfolder='text_encoder_2', torch_dtype=dtype, revision=revision)
51
- tokenizer_2 = T5TokenizerFast.from_pretrained(bfl_repo, subfolder='tokenizer_2', torch_dtype=dtype, revision=revision)
52
- vae = AutoencoderKL.from_pretrained(bfl_repo, subfolder='vae', torch_dtype=dtype, revision=revision)
53
- transformer = FluxTransformer2DModel.from_pretrained(bfl_repo, subfolder='transformer', torch_dtype=dtype, revision=revision)
54
-
55
- quantize(transformer, weights=qfloat8)
56
- freeze(transformer)
57
- quantize(text_encoder_2, weights=qfloat8)
58
- freeze(text_encoder_2)
59
-
60
- pipe = FluxPipeline(
61
- scheduler=scheduler,
62
- text_encoder=text_encoder,
63
- tokenizer=tokenizer_clip,
64
- text_encoder_2=None,
65
- tokenizer_2=tokenizer_2,
66
- vae=vae,
67
- transformer=None,
68
- )
69
- pipe.text_encoder_2 = text_encoder_2
70
- pipe.transformer = transformer
71
- pipe.enable_model_cpu_offload()
 
 
 
 
 
 
 
 
 
72
 
 
73
  print("Models and checkpoints preloaded.")
74
 
75
  def generate_description_prompt(subject, user_prompt, text_generator):
@@ -122,7 +132,7 @@ def generate_descriptions(user_prompt, seed_words_input, batch_size=100, max_ite
122
  return list(parsed_descriptions_queue)
123
 
124
  @spaces.GPU(duration=120)
125
- def generate_images(parsed_descriptions, max_iterations=2): # Set max_iterations to 1
126
  if len(parsed_descriptions) < MAX_IMAGES:
127
  prompts = parsed_descriptions
128
  else:
@@ -152,4 +162,4 @@ if __name__ == '__main__':
152
  allow_flagging='never' # Disable flagging
153
  )
154
 
155
- interface.launch(share=True)
 
41
  tokenizer = AutoTokenizer.from_pretrained(model_name)
42
  text_generator = pipeline('text-generation', model=model, tokenizer=tokenizer)
43
 
44
+ def initialize_diffusers():
45
+ from optimum.quanto import freeze, qfloat8, quantize
46
+ from diffusers import FlowMatchEulerDiscreteScheduler, AutoencoderKL
47
+ from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
48
+ from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
49
+ from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
50
+
51
+ bfl_repo = 'black-forest-labs/FLUX.1-schnell'
52
+ revision = 'refs/pr/1'
53
+
54
+ scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(bfl_repo, subfolder='scheduler', revision=revision)
55
+ text_encoder = CLIPTextModel.from_pretrained('openai/clip-vit-large-patch14', torch_dtype=dtype)
56
+ tokenizer = CLIPTokenizer.from_pretrained('openai/clip-vit-large-patch14', torch_dtype=dtype)
57
+ text_encoder_2 = T5EncoderModel.from_pretrained(bfl_repo, subfolder='text_encoder_2', torch_dtype=dtype, revision=revision)
58
+ tokenizer_2 = T5TokenizerFast.from_pretrained(bfl_repo, subfolder='tokenizer_2', torch_dtype=dtype, revision=revision)
59
+ vae = AutoencoderKL.from_pretrained(bfl_repo, subfolder='vae', torch_dtype=dtype, revision=revision)
60
+ transformer = FluxTransformer2DModel.from_pretrained(bfl_repo, subfolder='transformer', torch_dtype=dtype, revision=revision)
61
+
62
+ quantize(transformer, weights=qfloat8)
63
+ freeze(transformer)
64
+ quantize(text_encoder_2, weights=qfloat8)
65
+ freeze(text_encoder_2)
66
+
67
+ pipe = FluxPipeline(
68
+ scheduler=scheduler,
69
+ text_encoder=text_encoder,
70
+ tokenizer=tokenizer,
71
+ text_encoder_2=None,
72
+ tokenizer_2=tokenizer_2,
73
+ vae=vae,
74
+ transformer=None,
75
+ )
76
+ pipe.text_encoder_2 = text_encoder_2
77
+ pipe.transformer = transformer
78
+ pipe.enable_model_cpu_offload()
79
+
80
+ return pipe
81
 
82
+ pipe = initialize_diffusers()
83
  print("Models and checkpoints preloaded.")
84
 
85
  def generate_description_prompt(subject, user_prompt, text_generator):
 
132
  return list(parsed_descriptions_queue)
133
 
134
  @spaces.GPU(duration=120)
135
+ def generate_images(parsed_descriptions, max_iterations=1): # Set max_iterations to 1
136
  if len(parsed_descriptions) < MAX_IMAGES:
137
  prompts = parsed_descriptions
138
  else:
 
162
  allow_flagging='never' # Disable flagging
163
  )
164
 
165
+ interface.launch(share=True)