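"""Data collator for Natural-Instructions-style prompting.

Given a single NI task instance, the collator assembles a prompt from the
task definition, up to ``num_pos_examples`` positive and ``num_neg_examples``
negative demonstrations, and the instance input, then tokenizes source and
target for a seq2seq model. The prompt layout follows the Tk-Instruct
``DataCollatorForNI`` format. A minimal usage sketch is included at the
bottom of this file.
"""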
import logging
import random
import string
from dataclasses import dataclass
from typing import Any, Optional, Union

from transformers import PreTrainedTokenizerBase
from transformers.utils import PaddingStrategy

logger = logging.getLogger(__name__)


@dataclass
class DataCollatorForNI:
    # Tokenizer (and optional model) used to build the inputs.
    tokenizer: PreTrainedTokenizerBase
    model: Optional[Any] = None
    # Tokenization / padding options.
    padding: Union[bool, str, PaddingStrategy] = True
    max_source_length: Optional[int] = None
    max_target_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    label_pad_token_id: int = -100
    return_tensors: str = "pt"
    # Prompt-construction options.
    add_task_name: bool = False
    add_task_definition: bool = True
    num_pos_examples: int = 0
    num_neg_examples: int = 0
    add_explanation: bool = False
    tk_instruct: bool = False
    # If True, return raw prompt strings instead of tokenized tensors.
    text_only: bool = False
    # Fixed seed so encoding-schema and reference sampling are reproducible.
    random_gen: random.Random = random.Random(42)
    def __call__(self, batch, return_tensors=None):
        if return_tensors is None:
            return_tensors = self.return_tensors
        sources = []
        # This collator processes one instance at a time: wrap it in a list
        # here and strip the batch dimension again in the final return.
        batch = [batch]
        for instance in batch:
            if self.tk_instruct:
                all_valid_encodings = [
                    # instruction only
                    {
                        "add_task_name": False,
                        "add_task_definition": True,
                        "num_pos_examples": 0,
                        "num_neg_examples": 0,
                        "add_explanation": False,
                    },
                    # examples only
                    {
                        "add_task_name": False,
                        "add_task_definition": False,
                        "num_pos_examples": 2,
                        "num_neg_examples": 0,
                        "add_explanation": False,
                    },
                    # instruction + pos examples
                    {
                        "add_task_name": False,
                        "add_task_definition": True,
                        "num_pos_examples": 2,
                        "num_neg_examples": 0,
                        "add_explanation": False,
                    },
                    # instruction + pos examples + neg examples
                    {
                        "add_task_name": False,
                        "add_task_definition": True,
                        "num_pos_examples": 2,
                        "num_neg_examples": 2,
                        "add_explanation": False,
                    },
                    # instruction + pos examples (w. explanation)
                    {
                        "add_task_name": False,
                        "add_task_definition": True,
                        "num_pos_examples": 2,
                        "num_neg_examples": 0,
                        "add_explanation": True,
                    },
                ]
                encoding_schema = self.random_gen.choice(all_valid_encodings)
                add_task_name = encoding_schema["add_task_name"]
                add_task_definition = encoding_schema["add_task_definition"]
                num_pos_examples = encoding_schema["num_pos_examples"]
                num_neg_examples = encoding_schema["num_neg_examples"]
                add_explanation = encoding_schema["add_explanation"]
            else:
                add_task_name = self.add_task_name
                add_task_definition = self.add_task_definition
                num_pos_examples = self.num_pos_examples
                num_neg_examples = self.num_neg_examples
                add_explanation = self.add_explanation
task_input = "" | |
# add the input first. | |
task_input += "Now complete the following example -\n" | |
task_input += f"Input: {instance['Instance']['input'].strip()}" | |
if not task_input[-1] in string.punctuation: | |
task_input += "." | |
task_input += "\n" | |
task_input += "Output: " | |
task_name = "" | |
if add_task_name: | |
task_name += instance["Task"] + ". " | |
definition = "" | |
if add_task_definition: | |
if isinstance(instance["Definition"], list): | |
definition = ( | |
"Definition: " + instance["Definition"][0].strip() | |
) # TODO: should we use <Definition>? | |
else: | |
definition = "Definition: " + instance["Definition"].strip() | |
if not definition[-1] in string.punctuation: | |
definition += "." | |
definition += "\n\n" | |
            # Try to add positive examples, as long as the prompt still fits.
            pos_examples = []
            for idx, pos_example in enumerate(
                instance["Positive Examples"][:num_pos_examples]
            ):
                pos_example_str = f" Positive Example {idx + 1} -\n"
                pos_example_str += f"Input: {pos_example['input'].strip()}"
                if pos_example_str[-1] not in string.punctuation:
                    pos_example_str += "."
                pos_example_str += "\n"
                pos_example_str += f" Output: {pos_example['output'].strip()}"
                if pos_example_str[-1] not in string.punctuation:
                    pos_example_str += "."
                pos_example_str += "\n"
                if add_explanation and "explanation" in pos_example:
                    pos_example_str += (
                        f" Explanation: {pos_example['explanation'].strip()}"
                    )
                    if pos_example_str[-1] not in string.punctuation:
                        pos_example_str += "."
                    pos_example_str += "\n"
                pos_example_str += "\n"
                # Keep the example only if the prompt so far stays within
                # max_source_length; otherwise stop adding demonstrations.
                if (
                    len(
                        self.tokenizer(
                            definition
                            + " ".join(pos_examples)
                            + pos_example_str
                            + task_input
                        )["input_ids"]
                    )
                    <= self.max_source_length
                ):
                    pos_examples.append(pos_example_str)
                else:
                    break
            # Try to add negative examples, subject to the same length budget.
            neg_examples = []
            for idx, neg_example in enumerate(
                instance["Negative Examples"][:num_neg_examples]
            ):
                neg_example_str = f" Negative Example {idx + 1} -\n"
                neg_example_str += f"Input: {neg_example['input'].strip()}"
                if neg_example_str[-1] not in string.punctuation:
                    neg_example_str += "."
                neg_example_str += "\n"
                neg_example_str += f" Output: {neg_example['output'].strip()}"
                if neg_example_str[-1] not in string.punctuation:
                    neg_example_str += "."
                neg_example_str += "\n"
                if add_explanation and "explanation" in neg_example:
                    neg_example_str += (
                        f" Explanation: {neg_example['explanation'].strip()}"
                    )
                    if neg_example_str[-1] not in string.punctuation:
                        neg_example_str += "."
                    neg_example_str += "\n"
                neg_example_str += "\n"
                if (
                    len(
                        self.tokenizer(
                            definition
                            + " ".join(pos_examples)
                            + " ".join(neg_examples)
                            + neg_example_str
                            + task_input
                        )["input_ids"]
                    )
                    <= self.max_source_length
                ):
                    neg_examples.append(neg_example_str)
                else:
                    break
            source = (
                task_name
                + definition
                + "".join(pos_examples)
                + "".join(neg_examples)
                + task_input
            )
            # Hard-truncate any prompt that still exceeds the source budget.
            tokenized_source = self.tokenizer(source)["input_ids"]
            if len(tokenized_source) <= self.max_source_length:
                sources.append(source)
            else:
                sources.append(
                    self.tokenizer.decode(
                        tokenized_source[: self.max_source_length],
                        skip_special_tokens=True,
                    )
                )

        if self.text_only:
            model_inputs = {"inputs": sources}
        else:
            model_inputs = self.tokenizer(
                sources,
                max_length=self.max_source_length,
                padding=self.padding,
                return_tensors=return_tensors,
                truncation=True,
                pad_to_multiple_of=self.pad_to_multiple_of,
            )
if "output" in batch[0]["Instance"] and batch[0]["Instance"]["output"]: | |
# Randomly select one reference if multiple are provided. | |
labels = [self.random_gen.choice(ex["Instance"]["output"]) for ex in batch] | |
if self.text_only: | |
model_inputs["label"] = labels | |
else: | |
with self.tokenizer.as_target_tokenizer(): | |
labels = self.tokenizer( | |
labels, | |
max_length=self.max_target_length, | |
padding=self.padding, | |
return_tensors=self.return_tensors, | |
truncation=True, | |
pad_to_multiple_of=self.pad_to_multiple_of, | |
) | |
label_mask = labels["attention_mask"].bool() | |
model_inputs["label"] = labels["input_ids"].masked_fill( | |
~label_mask, self.label_pad_token_id | |
) | |
else: | |
model_inputs["label"] = None | |
        # Prepare decoder_input_ids from the labels when the model supports it.
        if (
            self.model is not None
            and hasattr(self.model, "prepare_decoder_input_ids_from_labels")
            and model_inputs["label"] is not None
            and not self.text_only
        ):
            decoder_input_ids = self.model.prepare_decoder_input_ids_from_labels(
                labels=model_inputs["label"]
            )
            model_inputs["decoder_input_ids"] = decoder_input_ids

        # Flatten away the singleton batch dimension added at the beginning.
        return {
            k: (v[0] if v is not None else None) for k, v in model_inputs.items()
        }
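

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustration only, not part of the collator itself).
# The "t5-small" checkpoint and the toy instance below are assumptions chosen
# for this example; any seq2seq tokenizer and NI-formatted instance should
# work the same way.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("t5-small")
    collator = DataCollatorForNI(
        tokenizer=tokenizer,
        max_source_length=1024,
        max_target_length=128,
        add_task_definition=True,
        num_pos_examples=1,
    )
    # Toy instance in the NI schema the collator expects (hypothetical data).
    instance = {
        "Task": "task001_example_sentiment",
        "Definition": ["Classify the sentiment of the given sentence."],
        "Positive Examples": [
            {"input": "I loved this film.", "output": "positive"}
        ],
        "Negative Examples": [],
        "Instance": {"input": "The plot was dull.", "output": ["negative"]},
    }
    encoded = collator(instance)
    # encoded holds 1-D tensors: input_ids, attention_mask, and label.
    print(tokenizer.decode(encoded["input_ids"], skip_special_tokens=True))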