BullseyeMxP committed on
Commit 1a23e22
1 Parent(s): a351b6b

Update app.py

Files changed (1)
  1. app.py +73 -67
app.py CHANGED
@@ -114,20 +114,25 @@ assert isinstance(tokenizer, PreTrainedTokenizer) or isinstance(tokenizer, PreTr
 print("Loading LLM")
 print("Loading VLM's custom text model")
 
-# Configure 4-bit quantization
+# Configure 4-bit quantization with more aggressive settings
 bnb_config = BitsAndBytesConfig(
     load_in_4bit=True,
     bnb_4bit_quant_type="nf4",
     bnb_4bit_compute_dtype=torch.float16,
     bnb_4bit_use_double_quant=True,
+    llm_int8_enable_fp32_cpu_offload=True
 )
 
 text_model = AutoModelForCausalLM.from_pretrained(
     CHECKPOINT_PATH / "text_model",
     device_map="auto",
     quantization_config=bnb_config,
-    torch_dtype=torch.float16
+    torch_dtype=torch.float16,
+    low_cpu_mem_usage=True
 )
+
+# Enable memory efficient attention
+text_model.config.use_memory_efficient_attention = True
 text_model.gradient_checkpointing_enable()
 text_model.eval()
 text_model = torch.compile(text_model)
@@ -140,15 +145,27 @@ image_adapter.eval()
 image_adapter.to("cuda")
 image_adapter = torch.compile(image_adapter)
 
+# Optimize CLIP model
+clip_model = clip_model.half() # Convert to FP16
+clip_model.eval()
+clip_model.requires_grad_(False)
+clip_model = torch.compile(clip_model)
+
+# Optimize image adapter
+image_adapter = image_adapter.half() # Convert to FP16
+image_adapter.eval()
+image_adapter.requires_grad_(False)
+image_adapter = torch.compile(image_adapter)
+
 @spaces.GPU()
 @torch.no_grad()
 def stream_chat(input_image: Image.Image, caption_type: str, caption_length: str | int, extra_options: list[str], name_input: str, custom_prompt: str) -> tuple[str, str]:
+    # Clear memory at the start
     torch.cuda.empty_cache()
     gc.collect()
 
-    # 'any' means no length specified
+    # Build prompt string
     length = None if caption_length == "any" else caption_length
-
     if isinstance(length, str):
         try:
             length = int(length)
@@ -176,57 +193,42 @@ def stream_chat(input_image: Image.Image, caption_type: str, caption_length: str
 
     if custom_prompt.strip() != "":
         prompt_str = custom_prompt.strip()
-
-    # For debugging
-    print(f"Prompt: {prompt_str}")
 
-    # Preprocess image
+    # Resize image to exact dimensions needed
     image = input_image.resize((384, 384), Image.LANCZOS)
-    image = image.convert('RGB') # Ensure the image has 3 channels
+    image = image.convert('RGB')
     pixel_values = TVF.pil_to_tensor(image).unsqueeze(0) / 255.0
-    pixel_values = TVF.normalize(pixel_values, [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) # Normalize for all 3 channels
+    pixel_values = TVF.normalize(pixel_values, [0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
     pixel_values = pixel_values.to('cuda', dtype=torch.float16)
 
-    # Embed image
-    with torch.amp.autocast_mode.autocast('cuda', dtype=torch.float16):
+    # Process image with optimized memory usage
+    with torch.amp.autocast('cuda', dtype=torch.float16):
        vision_outputs = clip_model(pixel_values=pixel_values, output_hidden_states=True)
        embedded_images = image_adapter(vision_outputs.hidden_states)
        embedded_images = embedded_images.to('cuda', dtype=torch.float16)
 
-    # Build the conversation
+    # Build the conversation with minimal overhead
     convo = [
-        {
-            "role": "system",
-            "content": "You are a helpful image captioner.",
-        },
-        {
-            "role": "user",
-            "content": prompt_str,
-        },
+        {"role": "system", "content": "You are a helpful image captioner."},
+        {"role": "user", "content": prompt_str},
     ]
 
-    # Format the conversation
-    convo_string = tokenizer.apply_chat_template(convo, tokenize = False, add_generation_prompt = True)
-    assert isinstance(convo_string, str)
-
-    # Tokenize the conversation
+    # Format and tokenize efficiently
+    convo_string = tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=True)
     convo_tokens = tokenizer.encode(convo_string, return_tensors="pt", add_special_tokens=False, truncation=False)
     prompt_tokens = tokenizer.encode(prompt_str, return_tensors="pt", add_special_tokens=False, truncation=False)
-    assert isinstance(convo_tokens, torch.Tensor) and isinstance(prompt_tokens, torch.Tensor)
-    convo_tokens = convo_tokens.squeeze(0) # Squeeze just to make the following easier
+
+    convo_tokens = convo_tokens.squeeze(0)
     prompt_tokens = prompt_tokens.squeeze(0)
 
-    # Calculate where to inject the image
+    # Calculate injection point
     eot_id_indices = (convo_tokens == tokenizer.convert_tokens_to_ids("<|eot_id|>")).nonzero(as_tuple=True)[0].tolist()
-    assert len(eot_id_indices) == 2, f"Expected 2 <|eot_id|> tokens, got {len(eot_id_indices)}"
-
-    preamble_len = eot_id_indices[1] - prompt_tokens.shape[0] # Number of tokens before the prompt
+    preamble_len = eot_id_indices[1] - prompt_tokens.shape[0]
 
-    # Embed the tokens
-    convo_tokens = convo_tokens.unsqueeze(0).to('cuda') # Keep as LongTensor
+    # Prepare input tensors efficiently
+    convo_tokens = convo_tokens.unsqueeze(0).to('cuda')
     convo_embeds = text_model.model.embed_tokens(convo_tokens)
 
-    # Construct the input
     input_embeds = torch.cat([
        convo_embeds[:, :preamble_len],
        embedded_images,
@@ -240,27 +242,31 @@ def stream_chat(input_image: Image.Image, caption_type: str, caption_length: str
     ], dim=1)
     attention_mask = torch.ones_like(input_ids)
 
-    # Debugging
-    print(f"Input to model: {repr(tokenizer.decode(input_ids[0]))}")
-
-    with torch.amp.autocast_mode.autocast('cuda', dtype=torch.float16):
+    # Generate with optimized settings
+    with torch.amp.autocast('cuda', dtype=torch.float16):
        generate_ids = text_model.generate(
            input_ids,
            inputs_embeds=input_embeds,
            attention_mask=attention_mask,
            max_new_tokens=300,
            do_sample=True,
-           suppress_tokens=None,
-           use_cache=True
+           use_cache=True,
+           pad_token_id=tokenizer.pad_token_id,
+           num_beams=1, # Disable beam search for faster generation
+           temperature=0.7, # Lower temperature for more focused generation
+           top_p=0.9, # Nucleus sampling for efficiency
+           repetition_penalty=1.2, # Prevent repetition
        )
 
-    # Trim off the prompt
+    # Process output efficiently
     generate_ids = generate_ids[:, input_ids.shape[1]:]
     if generate_ids[0][-1] == tokenizer.eos_token_id or generate_ids[0][-1] == tokenizer.convert_tokens_to_ids("<|eot_id|>"):
        generate_ids = generate_ids[:, :-1]
 
-    caption = tokenizer.batch_decode(generate_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False)[0]
+    caption = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)[0]
 
+    # Clear memory
+    del vision_outputs, embedded_images, input_embeds, generate_ids
     torch.cuda.empty_cache()
     gc.collect()
 
@@ -275,7 +281,7 @@ def process_directory(directory_path, caption_type, caption_length, extra_option
        img_path = os.path.join(directory_path, filename)
        img = Image.open(img_path)
 
-       prompt, caption = stream_chat(img, caption_type, caption_length, extra_options, name_input, custom_prompt)
+       _, caption = stream_chat(img, caption_type, caption_length, extra_options, name_input, custom_prompt)
 
        # Save caption to a .txt file
        txt_filename = os.path.splitext(filename)[0] + '.txt'
@@ -284,9 +290,29 @@ def process_directory(directory_path, caption_type, caption_length, extra_option
            f.write(caption)
 
        processed_images.append(img_path)
-       captions.append({"filename": filename, "caption": caption})
+       captions.append(caption)
 
-    return processed_images, captions
+    return processed_images, "\n\n".join(captions) # Join captions with double newline for readability
+
+def process_and_display(images, caption_type, caption_length, extra_options, name_input, custom_prompt):
+    processed_images = []
+    captions = []
+
+    for img_file in images:
+        img = Image.open(img_file.name)
+        _, caption = stream_chat(img, caption_type, caption_length, extra_options, name_input, custom_prompt)
+        processed_images.append(img_file.name)
+        captions.append(caption)
+
+    return processed_images, "\n\n".join(captions) # Join captions with double newline for readability
+
+def process_input(input_images, directory_path, caption_type, caption_length, extra_options, name_input, custom_prompt):
+    if directory_path:
+        return process_directory(directory_path, caption_type, caption_length, extra_options, name_input, custom_prompt)
+    elif input_images:
+        return process_and_display(input_images, caption_type, caption_length, extra_options, name_input, custom_prompt)
+    else:
+        return [], ""
 
 # Custom CSS for a futuristic, neon-inspired theme
 custom_css = """
@@ -439,27 +465,7 @@ with gr.Blocks(css=custom_css) as demo:
 
     with gr.Row():
        output_gallery = gr.Gallery(label="Processed Images", elem_classes="output-box")
-       output_text = gr.JSON(label="Generated Captions", elem_classes="output-box")
-
-    def process_and_display(images, caption_type, caption_length, extra_options, name_input, custom_prompt):
-        processed_images = []
-        captions = []
-
-        for img_file in images:
-            img = Image.open(img_file.name)
-            prompt, caption = stream_chat(img, caption_type, caption_length, extra_options, name_input, custom_prompt)
-            processed_images.append(img_file.name)
-            captions.append({"filename": img_file.name, "caption": caption})
-
-        return processed_images, captions
-
-    def process_input(input_images, directory_path, caption_type, caption_length, extra_options, name_input, custom_prompt):
-        if directory_path:
-            return process_directory(directory_path, caption_type, caption_length, extra_options, name_input, custom_prompt)
-        elif input_images:
-            return process_and_display(input_images, caption_type, caption_length, extra_options, name_input, custom_prompt)
-        else:
-            return [], []
+       output_text = gr.Textbox(label="Generated Captions", elem_classes="output-box", lines=10)
 
     run_button.click(
        fn=process_input,