Severian committed on
Commit
4992d18
1 Parent(s): 348afd0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -37
app.py CHANGED
@@ -182,25 +182,35 @@ def preprocess_image(input_image: Image.Image) -> torch.Tensor:
182
  def generate_caption(text_model, tokenizer, image_features, prompt_str: str, max_new_tokens: int = 300) -> str:
183
  prompt = tokenizer.encode(prompt_str, return_tensors='pt', padding=False, truncation=False, add_special_tokens=False)
184
  prompt_embeds = text_model.model.embed_tokens(prompt.to('cuda'))
185
- embedded_bos = text_model.model.embed_tokens(torch.tensor([[tokenizer.bos_token_id]], device=text_model.device, dtype=torch.int64))
186
- eot_embed = image_adapter.get_eot_embedding().unsqueeze(0).to(dtype=text_model.dtype)
187
 
188
- inputs_embeds = torch.cat([
189
- embedded_bos.expand(image_features.shape[0], -1, -1),
190
- image_features.to(dtype=embedded_bos.dtype),
191
- prompt_embeds.expand(image_features.shape[0], -1, -1),
192
- eot_embed.expand(image_features.shape[0], -1, -1),
193
- ], dim=1)
 
 
 
 
 
 
 
 
 
 
 
 
 
194
 
195
  input_ids = torch.cat([
196
- torch.tensor([[tokenizer.bos_token_id]], dtype=torch.long),
197
  torch.zeros((1, image_features.shape[1]), dtype=torch.long),
198
- prompt,
199
- torch.tensor([[tokenizer.convert_tokens_to_ids("<|eot_id|>")]], dtype=torch.long),
200
  ], dim=1).to('cuda')
201
  attention_mask = torch.ones_like(input_ids)
202
 
203
- generate_ids = text_model.generate(input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask, max_new_tokens=max_new_tokens, do_sample=True, suppress_tokens=None)
204
 
205
  generate_ids = generate_ids[:, input_ids.shape[1]:]
206
  if generate_ids[0][-1] == tokenizer.eos_token_id or generate_ids[0][-1] == tokenizer.convert_tokens_to_ids("<|eot_id|>"):
@@ -476,9 +486,9 @@ with gr.Blocks(theme="Hev832/Applio", css=css, fill_width=True, fill_height=True
476
  )
477
 
478
  with gr.Row():
479
- username = gr.Textbox(label="Username", placeholder="Enter your username")
480
  with gr.Row():
481
- password = gr.Textbox(label="Password", type="password", placeholder="Enter your password")
482
  with gr.Row():
483
  login_button = gr.Button("Login", size="sm")
484
  login_message = gr.Markdown(visible=False)
@@ -558,29 +568,29 @@ with gr.Blocks(theme="Hev832/Applio", css=css, fill_width=True, fill_height=True
558
  value="long",
559
  )
560
 
561
- with gr.Accordion("Extra Options", open=True):
562
- extra_options = gr.CheckboxGroup(
563
- choices=[
564
- "If there is a person/character in the image you must refer to them as {name}.",
565
- "Do NOT include information about people/characters that cannot be changed (like ethnicity, gender, etc), but do still include changeable attributes (like hair style).",
566
- "Include information about lighting.",
567
- "Include information about camera angle.",
568
- "Include information about whether there is a watermark or not.",
569
- "Include information about whether there are JPEG artifacts or not.",
570
- "If it is a photo you MUST include information about what camera was likely used and details such as aperture, shutter speed, ISO, etc.",
571
- "Do NOT include anything sexual; keep it PG.",
572
- "Do NOT mention the image's resolution.",
573
- "You MUST include information about the subjective aesthetic quality of the image from low to very high.",
574
- "Include information on the image's composition style, such as leading lines, rule of thirds, or symmetry.",
575
- "Do NOT mention any text that is in the image.",
576
- "Specify the depth of field and whether the background is in focus or blurred.",
577
- "If applicable, mention the likely use of artificial or natural lighting sources.",
578
- "Do NOT use any ambiguous language.",
579
- "Include whether the image is sfw, suggestive, or nsfw.",
580
- "ONLY describe the most important elements of the image."
581
- ],
582
- label="Select Extra Options"
583
- )
584
 
585
  name_input = gr.Textbox(label="Person/Character Name (if applicable)")
586
  gr.Markdown("**Note:** Name input is only used if an Extra Option is selected that requires it.")
 
182
  def generate_caption(text_model, tokenizer, image_features, prompt_str: str, max_new_tokens: int = 300) -> str:
183
  prompt = tokenizer.encode(prompt_str, return_tensors='pt', padding=False, truncation=False, add_special_tokens=False)
184
  prompt_embeds = text_model.model.embed_tokens(prompt.to('cuda'))
 
 
185
 
186
+ convo = [
187
+ {"role": "system", "content": "You are a helpful image captioner."},
188
+ {"role": "user", "content": prompt_str},
189
+ ]
190
+ convo_string = tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=True)
191
+ convo_tokens = tokenizer.encode(convo_string, return_tensors="pt", add_special_tokens=False, truncation=False)
192
+ convo_tokens = convo_tokens.squeeze(0)
193
+
194
+ eot_id_indices = (convo_tokens == tokenizer.convert_tokens_to_ids("<|eot_id|>")).nonzero(as_tuple=True)[0].tolist()
195
+ assert len(eot_id_indices) == 2, f"Expected 2 <|eot_id|> tokens, got {len(eot_id_indices)}"
196
+ preamble_len = eot_id_indices[1] - prompt.shape[1]
197
+
198
+ convo_embeds = text_model.model.embed_tokens(convo_tokens.unsqueeze(0).to('cuda'))
199
+
200
+ input_embeds = torch.cat([
201
+ convo_embeds[:, :preamble_len],
202
+ image_features.to(dtype=convo_embeds.dtype),
203
+ convo_embeds[:, preamble_len:],
204
+ ], dim=1).to('cuda')
205
 
206
  input_ids = torch.cat([
207
+ convo_tokens[:preamble_len].unsqueeze(0),
208
  torch.zeros((1, image_features.shape[1]), dtype=torch.long),
209
+ convo_tokens[preamble_len:].unsqueeze(0),
 
210
  ], dim=1).to('cuda')
211
  attention_mask = torch.ones_like(input_ids)
212
 
213
+ generate_ids = text_model.generate(input_ids, inputs_embeds=input_embeds, attention_mask=attention_mask, max_new_tokens=max_new_tokens, do_sample=True, suppress_tokens=None)
214
 
215
  generate_ids = generate_ids[:, input_ids.shape[1]:]
216
  if generate_ids[0][-1] == tokenizer.eos_token_id or generate_ids[0][-1] == tokenizer.convert_tokens_to_ids("<|eot_id|>"):
 
486
  )
487
 
488
  with gr.Row():
489
+ username = gr.Textbox(label="Username", placeholder="Enter your username", value="ugd")
490
  with gr.Row():
491
+ password = gr.Textbox(label="Password", type="password", placeholder="Enter your password", value="ugd!")
492
  with gr.Row():
493
  login_button = gr.Button("Login", size="sm")
494
  login_message = gr.Markdown(visible=False)
 
568
  value="long",
569
  )
570
 
571
+ with gr.Accordion("Extra Options", open=True):
572
+ extra_options = gr.CheckboxGroup(
573
+ choices=[
574
+ "If there is a person/character in the image you must refer to them as {name}.",
575
+ "Do NOT include information about people/characters that cannot be changed (like ethnicity, gender, etc), but do still include changeable attributes (like hair style).",
576
+ "Include information about lighting.",
577
+ "Include information about camera angle.",
578
+ "Include information about whether there is a watermark or not.",
579
+ "Include information about whether there are JPEG artifacts or not.",
580
+ "If it is a photo you MUST include information about what camera was likely used and details such as aperture, shutter speed, ISO, etc.",
581
+ "Do NOT include anything sexual; keep it PG.",
582
+ "Do NOT mention the image's resolution.",
583
+ "You MUST include information about the subjective aesthetic quality of the image from low to very high.",
584
+ "Include information on the image's composition style, such as leading lines, rule of thirds, or symmetry.",
585
+ "Do NOT mention any text that is in the image.",
586
+ "Specify the depth of field and whether the background is in focus or blurred.",
587
+ "If applicable, mention the likely use of artificial or natural lighting sources.",
588
+ "Do NOT use any ambiguous language.",
589
+ "Include whether the image is sfw, suggestive, or nsfw.",
590
+ "ONLY describe the most important elements of the image."
591
+ ],
592
+ label="Select Extra Options"
593
+ )
594
 
595
  name_input = gr.Textbox(label="Person/Character Name (if applicable)")
596
  gr.Markdown("**Note:** Name input is only used if an Extra Option is selected that requires it.")