# cogen/dataset.py
from transformers import AutoProcessor
from PIL import Image
import os
import torch
import pickle
## ACTUAL INPUT CONSTRUCTION
BASE_SPEAKER_LEN = 787
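# BASE_SPEAKER_LEN is assumed to be the fixed number of non-caption tokens
# (chat-template text plus image tokens) in a processed speaker prompt;
# create_speaker_caption_mask uses it to isolate the caption tokens at the
# end of each sequence.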
def joint_listener_input(processor, context_images, description, device):
    # Preliminaries
    img_dir = "tangram_pngs"
    raw_images = process_images(img_dir, context_images)
    target_anno = description.lower()
    prompt = construct_listener_full_prompt(
        processor, target_anno, 0, "verbose_instruction"
    )

    # Listener processing
    outputs = processor(
        text=[prompt],
        images=[raw_images],
        return_tensors="pt"
    ).to(device)
    l_input_tokens = outputs['input_ids'][:, :-2]
    l_attn_mask = outputs['attention_mask'][:, :-2]
    l_attn_mask[(l_input_tokens == 0).bool()] = 0
    images = outputs['pixel_values']
    l_image_attn_mask = outputs['pixel_attention_mask']

    # Speaker processing
    prompts = []
    for i in range(10):
        prompt = construct_speaker_full_prompt(processor, description, i, "information_after")
        prompts.append(prompt)
    outputs = processor(
        text=prompts,
        images=[raw_images] * 10,
        padding='longest',
        return_tensors="pt"
    ).to(device)
    s_input_tokens = outputs['input_ids'][:, :-1]
    s_attn_mask = outputs['attention_mask'][:, :-1]
    s_attn_mask[(s_input_tokens == 0).bool()] = 0
    s_image_attn_mask = outputs['pixel_attention_mask']
    s_target_tokens = outputs['input_ids'][:, 1:]

    s_target_mask = []
    for i in range(10):
        curr_mask = create_speaker_caption_mask(outputs['input_ids'][i], s_attn_mask[i])
        s_target_mask.append(curr_mask)
    s_target_mask = torch.stack(s_target_mask, dim=0)

    return images, l_input_tokens, l_attn_mask, l_image_attn_mask, s_input_tokens.unsqueeze(0), \
        s_attn_mask.unsqueeze(0), s_image_attn_mask.unsqueeze(0), s_target_mask.unsqueeze(0), \
        s_target_tokens.unsqueeze(0)
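# Note: the listener tensors above cover a single (context, caption) pair, while
# the speaker prompts enumerate all 10 candidate targets for the same caption;
# the unsqueeze(0) calls add a batch dimension of 1 over that set of 10.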
def joint_speaker_input(processor, image_paths, target_path, device):
    # Get the prompt
    img_dir = "tangram_pngs"
    raw_images = process_images(img_dir, image_paths)
    target_idx = image_paths.index(target_path)
    base_prompt = construct_speaker_base_prompt(processor, target_idx, "information_after", process=True)

    # Create the basic input
    outputs = processor(
        text=[base_prompt],
        images=[raw_images],
        return_tensors="pt"
    ).to(device)
    input_tokens = outputs['input_ids']
    attn_mask = outputs['attention_mask']
    attn_mask[(input_tokens == 0).bool()] = 0
    images = outputs['pixel_values']
    image_attn_mask = outputs['pixel_attention_mask']

    return input_tokens, attn_mask, images, image_attn_mask, torch.LongTensor([target_idx]).to(device)
## UTILITIES
def get_processor():
    checkpoint = "HuggingFaceM4/idefics2-8b"
    processor = AutoProcessor.from_pretrained(checkpoint, do_image_splitting=False,
                                              size={"longest_edge": 448, "shortest_edge": 224})
    return processor
def get_index_to_token():
    index_to_token_path = "index_to_token.pkl"
    with open(index_to_token_path, 'rb') as f:
        index_to_token = pickle.load(f)
    return index_to_token
def process_images(img_dir, context_images):
    raw_images = []
    for img in context_images:
        image_path = os.path.join(img_dir, img)
        raw_image = Image.open(image_path).convert('RGB')
        raw_images.append(raw_image)
    return raw_images
def create_speaker_caption_mask(all_token_ids, text_mask):
    # Overall token comp: pad + base + caption
    padding_tokens = torch.sum(all_token_ids == 0).item()
    caption_tokens = all_token_ids.shape[0] - (padding_tokens + BASE_SPEAKER_LEN)

    # Construct a mask where the last caption tokens are 1
    target_mask = torch.zeros_like(text_mask)
    target_mask[-caption_tokens:] = 1
    return target_mask.bool()
def construct_listener_full_prompt(processor, target_anno, target_idx, comprehension_prompt_type="verbose_instruction"):
    target_anno = target_anno.lower().strip()
    messages = []
    if comprehension_prompt_type == "verbose_instruction":
        # User side: Intro
        messages.append(
            {
                "role" : "user",
                "content" : [
                    {"type" : "text", "text" : "You will be presented with a sequence of 10 images and a caption describing exactly one of them. "},
                    {"type" : "text", "text" : "Your task is to guess which image the caption describes. "},
                ]
            }
        )

        # User side: Images
        for i in range(10):
            if i == 0:
                messages[0]["content"].append({"type" : "text", "text" : f" Image {i}: "})
            else:
                messages[0]["content"].append({"type" : "text", "text" : f", Image {i}: "})
            messages[0]["content"].append({"type" : "image"})

        # User side: Caption
        messages[0]["content"].append({"type" : "text", "text" : f". Caption: {target_anno}"})
        messages[0]["content"].append({"type" : "text", "text" : " Does this caption describe Image 0, 1, 2, 3, 4, 5, 6, 7, 8 or 9?"})

        # Model side: Guess
        messages.append(
            {
                "role" : "assistant",
                "content" : [
                    {"type" : "text", "text" : f"The caption describes Image {target_idx}"}
                ]
            }
        )
    else:
        assert(False)

    return processor.apply_chat_template(messages, add_generation_prompt=False).strip()
def construct_speaker_full_prompt(processor, target_anno, target_idx,
                                  generation_prompt_type="information_after"):
    messages = construct_speaker_base_prompt(processor, target_idx, generation_prompt_type)

    # Assistant response
    target_anno = target_anno.lower().strip()
    messages.append(
        {
            "role" : "assistant",
            "content" : [
                {"type" : "text", "text" : target_anno}
            ]
        }
    )
    return processor.apply_chat_template(messages, add_generation_prompt=False).strip()
def construct_speaker_base_prompt(processor, target_idx, generation_prompt_type="information_after", process=False):
    messages = []
    if generation_prompt_type == "information_after":
        # User side: Intro
        messages.append(
            {
                "role" : "user",
                "content" : [
                    {"type" : "text", "text" : "You will be presented with a sequence of 10 images and be assigned a target image. "},
                    {"type" : "text", "text" : "Your task is to produce a caption for your target image such that anyone could guess the image from your description. "},
                ]
            }
        )

        # User side: Images
        for i in range(10):
            if i == 0:
                messages[0]["content"].append({"type" : "text", "text" : f" Image {i}: "})
            else:
                messages[0]["content"].append({"type" : "text", "text" : f", Image {i}: "})
            messages[0]["content"].append({"type" : "image"})

        # User side: Target assignment
        messages[0]["content"].append({"type" : "text", "text" : f". Your target image is Image {target_idx}. Produce your caption now."})
    else:
        assert(False)

    if process:
        prompt = processor.apply_chat_template(messages, add_generation_prompt=True).strip()
        return prompt
    else:
        return messages
def process_idefics_listener_generation_input(speaker_context, captions, processor, img_dir, num_samples, device):
    # First construct the prompts
    prompts, raw_images = get_listener_generation_prompts(speaker_context, captions, num_samples, img_dir, processor)

    # Process the prompts
    listener_inputs = processor(
        text=prompts,
        images=raw_images,
        padding='longest',
        return_tensors='pt'
    )
    input_tokens = listener_inputs['input_ids'][:, :-2].to(device)
    attn_mask = listener_inputs['attention_mask'][:, :-2].to(device)
    attn_mask[input_tokens == 0] = 0
    images = listener_inputs['pixel_values'].to(device)
    image_attn_mask = listener_inputs['pixel_attention_mask'].to(device)
    return input_tokens, attn_mask, images, image_attn_mask
def get_listener_generation_prompts(speaker_contexts, captions, num_samples, img_dir, processor):
    prompts = []
    all_raw_images = []
    for i, speaker_context in enumerate(speaker_contexts):
        raw_images = process_images(img_dir, speaker_context)
        for j in range(num_samples):
            curr_idx = i * num_samples + j
            caption = captions[curr_idx]
            prompt = construct_listener_full_prompt(processor, caption, 0, "verbose_instruction")
            prompts.append(prompt)
            all_raw_images.append(raw_images)
    return prompts, all_raw_images
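# Usage sketch (hypothetical file names; assumes ten tangram PNGs exist under
# "tangram_pngs/" and a CUDA device is available):
#
#   processor = get_processor()
#   context = [f"tangram_{i}.png" for i in range(10)]   # hypothetical file names
#   speaker_inputs = joint_speaker_input(processor, context, context[3], "cuda")
#   listener_inputs = joint_listener_input(processor, context, "a person kneeling down", "cuda")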