Update app.py
Browse files
app.py
CHANGED
@@ -182,25 +182,35 @@ def preprocess_image(input_image: Image.Image) -> torch.Tensor:
|
|
182 |
def generate_caption(text_model, tokenizer, image_features, prompt_str: str, max_new_tokens: int = 300) -> str:
|
183 |
prompt = tokenizer.encode(prompt_str, return_tensors='pt', padding=False, truncation=False, add_special_tokens=False)
|
184 |
prompt_embeds = text_model.model.embed_tokens(prompt.to('cuda'))
|
185 |
-
embedded_bos = text_model.model.embed_tokens(torch.tensor([[tokenizer.bos_token_id]], device=text_model.device, dtype=torch.int64))
|
186 |
-
eot_embed = image_adapter.get_eot_embedding().unsqueeze(0).to(dtype=text_model.dtype)
|
187 |
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
194 |
|
195 |
input_ids = torch.cat([
|
196 |
-
|
197 |
torch.zeros((1, image_features.shape[1]), dtype=torch.long),
|
198 |
-
|
199 |
-
torch.tensor([[tokenizer.convert_tokens_to_ids("<|eot_id|>")]], dtype=torch.long),
|
200 |
], dim=1).to('cuda')
|
201 |
attention_mask = torch.ones_like(input_ids)
|
202 |
|
203 |
-
generate_ids = text_model.generate(input_ids, inputs_embeds=
|
204 |
|
205 |
generate_ids = generate_ids[:, input_ids.shape[1]:]
|
206 |
if generate_ids[0][-1] == tokenizer.eos_token_id or generate_ids[0][-1] == tokenizer.convert_tokens_to_ids("<|eot_id|>"):
|
@@ -476,9 +486,9 @@ with gr.Blocks(theme="Hev832/Applio", css=css, fill_width=True, fill_height=True
|
|
476 |
)
|
477 |
|
478 |
with gr.Row():
|
479 |
-
username = gr.Textbox(label="Username", placeholder="Enter your username")
|
480 |
with gr.Row():
|
481 |
-
password = gr.Textbox(label="Password", type="password", placeholder="Enter your password")
|
482 |
with gr.Row():
|
483 |
login_button = gr.Button("Login", size="sm")
|
484 |
login_message = gr.Markdown(visible=False)
|
@@ -558,29 +568,29 @@ with gr.Blocks(theme="Hev832/Applio", css=css, fill_width=True, fill_height=True
|
|
558 |
value="long",
|
559 |
)
|
560 |
|
561 |
-
|
562 |
-
|
563 |
-
|
564 |
-
|
565 |
-
|
566 |
-
|
567 |
-
|
568 |
-
|
569 |
-
|
570 |
-
|
571 |
-
|
572 |
-
|
573 |
-
|
574 |
-
|
575 |
-
|
576 |
-
|
577 |
-
|
578 |
-
|
579 |
-
|
580 |
-
|
581 |
-
|
582 |
-
|
583 |
-
|
584 |
|
585 |
name_input = gr.Textbox(label="Person/Character Name (if applicable)")
|
586 |
gr.Markdown("**Note:** Name input is only used if an Extra Option is selected that requires it.")
|
|
|
182 |
def generate_caption(text_model, tokenizer, image_features, prompt_str: str, max_new_tokens: int = 300) -> str:
|
183 |
prompt = tokenizer.encode(prompt_str, return_tensors='pt', padding=False, truncation=False, add_special_tokens=False)
|
184 |
prompt_embeds = text_model.model.embed_tokens(prompt.to('cuda'))
|
|
|
|
|
185 |
|
186 |
+
convo = [
|
187 |
+
{"role": "system", "content": "You are a helpful image captioner."},
|
188 |
+
{"role": "user", "content": prompt_str},
|
189 |
+
]
|
190 |
+
convo_string = tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=True)
|
191 |
+
convo_tokens = tokenizer.encode(convo_string, return_tensors="pt", add_special_tokens=False, truncation=False)
|
192 |
+
convo_tokens = convo_tokens.squeeze(0)
|
193 |
+
|
194 |
+
eot_id_indices = (convo_tokens == tokenizer.convert_tokens_to_ids("<|eot_id|>")).nonzero(as_tuple=True)[0].tolist()
|
195 |
+
assert len(eot_id_indices) == 2, f"Expected 2 <|eot_id|> tokens, got {len(eot_id_indices)}"
|
196 |
+
preamble_len = eot_id_indices[1] - prompt.shape[1]
|
197 |
+
|
198 |
+
convo_embeds = text_model.model.embed_tokens(convo_tokens.unsqueeze(0).to('cuda'))
|
199 |
+
|
200 |
+
input_embeds = torch.cat([
|
201 |
+
convo_embeds[:, :preamble_len],
|
202 |
+
image_features.to(dtype=convo_embeds.dtype),
|
203 |
+
convo_embeds[:, preamble_len:],
|
204 |
+
], dim=1).to('cuda')
|
205 |
|
206 |
input_ids = torch.cat([
|
207 |
+
convo_tokens[:preamble_len].unsqueeze(0),
|
208 |
torch.zeros((1, image_features.shape[1]), dtype=torch.long),
|
209 |
+
convo_tokens[preamble_len:].unsqueeze(0),
|
|
|
210 |
], dim=1).to('cuda')
|
211 |
attention_mask = torch.ones_like(input_ids)
|
212 |
|
213 |
+
generate_ids = text_model.generate(input_ids, inputs_embeds=input_embeds, attention_mask=attention_mask, max_new_tokens=max_new_tokens, do_sample=True, suppress_tokens=None)
|
214 |
|
215 |
generate_ids = generate_ids[:, input_ids.shape[1]:]
|
216 |
if generate_ids[0][-1] == tokenizer.eos_token_id or generate_ids[0][-1] == tokenizer.convert_tokens_to_ids("<|eot_id|>"):
|
|
|
486 |
)
|
487 |
|
488 |
with gr.Row():
|
489 |
+
username = gr.Textbox(label="Username", placeholder="Enter your username", value="ugd")
|
490 |
with gr.Row():
|
491 |
+
password = gr.Textbox(label="Password", type="password", placeholder="Enter your password", value="ugd!")
|
492 |
with gr.Row():
|
493 |
login_button = gr.Button("Login", size="sm")
|
494 |
login_message = gr.Markdown(visible=False)
|
|
|
568 |
value="long",
|
569 |
)
|
570 |
|
571 |
+
with gr.Accordion("Extra Options", open=True):
|
572 |
+
extra_options = gr.CheckboxGroup(
|
573 |
+
choices=[
|
574 |
+
"If there is a person/character in the image you must refer to them as {name}.",
|
575 |
+
"Do NOT include information about people/characters that cannot be changed (like ethnicity, gender, etc), but do still include changeable attributes (like hair style).",
|
576 |
+
"Include information about lighting.",
|
577 |
+
"Include information about camera angle.",
|
578 |
+
"Include information about whether there is a watermark or not.",
|
579 |
+
"Include information about whether there are JPEG artifacts or not.",
|
580 |
+
"If it is a photo you MUST include information about what camera was likely used and details such as aperture, shutter speed, ISO, etc.",
|
581 |
+
"Do NOT include anything sexual; keep it PG.",
|
582 |
+
"Do NOT mention the image's resolution.",
|
583 |
+
"You MUST include information about the subjective aesthetic quality of the image from low to very high.",
|
584 |
+
"Include information on the image's composition style, such as leading lines, rule of thirds, or symmetry.",
|
585 |
+
"Do NOT mention any text that is in the image.",
|
586 |
+
"Specify the depth of field and whether the background is in focus or blurred.",
|
587 |
+
"If applicable, mention the likely use of artificial or natural lighting sources.",
|
588 |
+
"Do NOT use any ambiguous language.",
|
589 |
+
"Include whether the image is sfw, suggestive, or nsfw.",
|
590 |
+
"ONLY describe the most important elements of the image."
|
591 |
+
],
|
592 |
+
label="Select Extra Options"
|
593 |
+
)
|
594 |
|
595 |
name_input = gr.Textbox(label="Person/Character Name (if applicable)")
|
596 |
gr.Markdown("**Note:** Name input is only used if an Extra Option is selected that requires it.")
|