# Uploaded by chendl
# Commit 0b7b08a: "Add application file"
import torch
import random
import numpy as np
from functools import partial
import torch.nn.functional as nnf
from torchvision import transforms as T
# Many of the approaches here are inspired by the wonderful paper by O'Connor and Andreas (2021).
# https://github.com/lingo-mit/context-ablations
def get_text_perturb_fn(text_perturb_fn):
    """Look up a text perturbation callable by name.

    The recognised names are the shuffling strategies implemented as methods
    of ``TextShuffler``. The lookup previously returned bare module-level
    names, which raised ``NameError`` because those functions only exist as
    methods; the name is now resolved against a ``TextShuffler`` instance.

    Args:
        text_perturb_fn: One of ``"shuffle_nouns_and_adj"``,
            ``"shuffle_allbut_nouns_and_adj"``, ``"shuffle_within_trigrams"``,
            ``"shuffle_all_words"``, ``"shuffle_trigrams"``, or ``None``.

    Returns:
        A callable that takes a text string and returns the perturbed string,
        or ``None`` when ``text_perturb_fn`` is ``None`` or not recognised
        (a message is printed in the latter case).
    """
    known_names = (
        "shuffle_nouns_and_adj",
        "shuffle_allbut_nouns_and_adj",
        "shuffle_within_trigrams",
        "shuffle_all_words",
        "shuffle_trigrams",
    )

    if text_perturb_fn is None:
        return None

    if text_perturb_fn not in known_names:
        print("Unknown text perturbation function: {}, returning None".format(text_perturb_fn))
        return None

    # Backward compatibility: if a plain module-level function with this name
    # is defined somewhere in the module, keep returning it as before.
    module_level_fn = globals().get(text_perturb_fn)
    if callable(module_level_fn):
        return module_level_fn

    # Otherwise bind the method to a fresh shuffler. Building a TextShuffler
    # loads the spaCy model, so callers should fetch the callable once and
    # reuse it rather than calling this function per example.
    return getattr(TextShuffler(), text_perturb_fn)
def get_image_perturb_fn(image_perturb_fn):
    """Look up an image perturbation callable by name.

    Args:
        image_perturb_fn: ``"shuffle_rows_4"``, ``"shuffle_patches_9"``,
            ``"shuffle_cols_4"``, or ``None``.

    Returns:
        ``shuffle_rows`` with ``n_rows=4``, ``shuffle_patches`` with
        ``n_ratio=3`` or ``shuffle_columns`` with ``n_cols=4`` (each wrapped
        in ``functools.partial``), or ``None`` when the name is ``None`` or
        not recognised (a message is printed in the latter case).
    """
    # The factories are lambdas so that the target function is only looked up
    # for the branch that is actually selected.
    candidates = (
        ("shuffle_rows_4", lambda: partial(shuffle_rows, n_rows=4)),
        ("shuffle_patches_9", lambda: partial(shuffle_patches, n_ratio=3)),
        ("shuffle_cols_4", lambda: partial(shuffle_columns, n_cols=4)),
    )
    for known_name, build in candidates:
        if image_perturb_fn == known_name:
            return build()

    if image_perturb_fn is not None:
        print("Unknown image perturbation function: {}, returning None".format(image_perturb_fn))
    return None
class TextShuffler:
    """Word-order perturbations for text.

    Every method takes a string and returns a string built by joining the
    (re-ordered) tokens with single spaces, so the original spacing around
    punctuation is not preserved by the tokenizer-based methods.
    """

    def __init__(self):
        # Imported here so spaCy is only needed once a shuffler is created.
        import spacy

        self.nlp = spacy.load("en_core_web_sm")

    def shuffle_nouns_and_adj(self, ex):
        """Permute nouns among noun slots and adjectives among adjective slots.

        Tokens tagged NN/NNS/NNP/NNPS are shuffled with each other, then
        tokens tagged JJ/JJR/JJS are shuffled with each other; all other
        tokens stay in place.
        """
        noun_tags = ('NN', 'NNS', 'NNP', 'NNPS')
        adjective_tags = ('JJ', 'JJR', 'JJS')

        doc = self.nlp(ex)
        words = np.array([tok.text for tok in doc])

        noun_positions = []
        adjective_positions = []
        for position, tok in enumerate(doc):
            if tok.tag_ in noun_tags:
                noun_positions.append(position)
            elif tok.tag_ in adjective_tags:
                adjective_positions.append(position)

        # Nouns first, then adjectives, each group permuted independently.
        words[noun_positions] = np.random.permutation(words[noun_positions])
        words[adjective_positions] = np.random.permutation(words[adjective_positions])
        return " ".join(words)

    def shuffle_all_words(self, ex):
        """Randomly permute the pieces obtained by splitting *ex* on single spaces."""
        pieces = ex.split(" ")
        return " ".join(np.random.permutation(pieces))

    def shuffle_allbut_nouns_and_adj(self, ex):
        """Permute every token that is NOT a noun or an adjective.

        Tokens tagged NN/NNS/NNP/NNPS/JJ/JJR/JJS keep their positions; the
        remaining tokens are shuffled among the remaining positions.
        """
        pinned_tags = ('NN', 'NNS', 'NNP', 'NNPS', 'JJ', 'JJR', 'JJS')

        doc = self.nlp(ex)
        words = np.array([tok.text for tok in doc])
        pinned_positions = [
            position for position, tok in enumerate(doc) if tok.tag_ in pinned_tags
        ]

        # Boolean mask of the positions that are allowed to move.
        movable = np.ones(words.shape[0], dtype=bool)
        movable[pinned_positions] = False

        words[movable] = np.random.permutation(words[movable])
        return " ".join(words)

    def get_trigrams(self, sentence):
        """Split *sentence* into consecutive groups of three items.

        The last group holds the one or two leftover items when the length is
        not a multiple of three. Each group is a new list.
        """
        # Adapted from https://github.com/lingo-mit/context-ablations/blob/478fb18a9f9680321f0d37dc999ea444e9287cc0/code/transformers/src/transformers/data/data_augmentation.py
        n_items = len(sentence)
        return [
            [sentence[position] for position in range(start, min(start + 3, n_items))]
            for start in range(0, n_items, 3)
        ]

    def trigram_shuffle(self, sentence):
        """Shuffle items inside each trigram while keeping the trigram order."""
        groups = self.get_trigrams(sentence)
        for group in groups:
            random.shuffle(group)
        return " ".join(" ".join(group) for group in groups)

    def shuffle_within_trigrams(self, ex):
        """Tokenize *ex* with NLTK and shuffle the tokens inside each trigram."""
        import nltk

        return self.trigram_shuffle(nltk.word_tokenize(ex))

    def shuffle_trigrams(self, ex):
        """Tokenize *ex* with NLTK and shuffle the order of whole trigrams.

        Token order inside each trigram is left untouched.
        """
        import nltk

        groups = self.get_trigrams(nltk.word_tokenize(ex))
        random.shuffle(groups)
        return " ".join(" ".join(group) for group in groups)
def _handle_image_4shuffle(x):
return_image = False
if not isinstance(x, torch.Tensor):
# print(f"x is not a tensor: {type(x)}. Trying to handle but fix this or I'll annoy you with this log")
t = torch.tensor(np.array(x)).unsqueeze(dim=0).float()
t = t.permute(0, 3, 1, 2)
return_image = True
return t, return_image
if len(x.shape) != 4:
#print("You did not send a tensor of shape NxCxWxH. Unsqueezing not but fix this or I'll annoy you with this log")
return x.unsqueeze(dim=0), return_image
else:
# Good boi
return x, return_image
def shuffle_rows(x, n_rows=7):
    """Randomly reorder horizontal strips of an image.

    The second-to-last dimension is cut into strips of
    ``x.shape[-2] // n_rows`` pixels spanning the full last dimension, and
    the strips of each batch element are permuted independently.

    Args:
        x: A tensor (3-D or 4-D, channel-first) or a non-tensor image that
            ``_handle_image_4shuffle`` can convert.
        n_rows: Number of strips to cut the image into.

    Returns:
        The shuffled tensor with all size-1 dimensions squeezed out, or a PIL
        image (via ``ToPILImage`` on the uint8-cast result) when the input
        was not a tensor.

    NOTE(review): when the size is not a multiple of ``n_rows`` the leftover
    pixels look to be dropped by ``unfold`` and left at zero by ``fold``;
    confirm this is acceptable for the image sizes in use.
    """
    batch, as_image = _handle_image_4shuffle(x)
    height, width = batch.shape[-2:]
    strip_height = height // n_rows
    window = (strip_height, width)

    # One column of `strips` per horizontal strip of the image.
    strips = nnf.unfold(batch, kernel_size=window, stride=strip_height, padding=0)
    # Draw a separate random strip order for every sample in the batch.
    shuffled = torch.stack(
        [sample[:, torch.randperm(sample.shape[-1])] for sample in strips], dim=0
    )
    # Reassemble the strips into an image of the original spatial size.
    restored = nnf.fold(
        shuffled, batch.shape[-2:], kernel_size=window, stride=strip_height, padding=0
    )

    result = restored.squeeze()
    if not as_image:
        return result
    return T.ToPILImage()(result.type(torch.uint8))
def shuffle_columns(x, n_cols=7):
    """Randomly reorder vertical strips of an image.

    The last dimension is cut into strips of ``x.shape[-1] // n_cols``
    pixels spanning the full second-to-last dimension, and the strips of
    each batch element are permuted independently.

    Args:
        x: A tensor (3-D or 4-D, channel-first) or a non-tensor image that
            ``_handle_image_4shuffle`` can convert.
        n_cols: Number of strips to cut the image into.

    Returns:
        The shuffled tensor with all size-1 dimensions squeezed out, or a PIL
        image (via ``ToPILImage`` on the uint8-cast result) when the input
        was not a tensor.

    NOTE(review): when the size is not a multiple of ``n_cols`` the leftover
    pixels look to be dropped by ``unfold`` and left at zero by ``fold``;
    confirm this is acceptable for the image sizes in use.
    """
    batch, as_image = _handle_image_4shuffle(x)
    height, width = batch.shape[-2:]
    strip_width = width // n_cols
    window = (height, strip_width)

    # One column of `strips` per vertical strip of the image.
    strips = nnf.unfold(batch, kernel_size=window, stride=strip_width, padding=0)
    # Draw a separate random strip order for every sample in the batch.
    shuffled = torch.stack(
        [sample[:, torch.randperm(sample.shape[-1])] for sample in strips], dim=0
    )
    # Reassemble the strips into an image of the original spatial size.
    restored = nnf.fold(
        shuffled, batch.shape[-2:], kernel_size=window, stride=strip_width, padding=0
    )

    result = restored.squeeze()
    if not as_image:
        return result
    return T.ToPILImage()(result.type(torch.uint8))
def shuffle_patches(x, n_ratio=4):
    """Randomly reorder rectangular patches of an image.

    Each of the two spatial dimensions is divided by ``n_ratio`` to get the
    patch size, giving a grid of non-overlapping patches (``n_ratio`` per
    side when the sizes divide evenly). The patches of each batch element
    are permuted independently.

    Args:
        x: A tensor (3-D or 4-D, channel-first) or a non-tensor image that
            ``_handle_image_4shuffle`` can convert.
        n_ratio: Number of patches along each spatial dimension.

    Returns:
        The shuffled tensor with all size-1 dimensions squeezed out, or a PIL
        image (via ``ToPILImage`` on the uint8-cast result) when the input
        was not a tensor.

    NOTE(review): when a spatial size is not a multiple of ``n_ratio`` the
    leftover pixels look to be dropped by ``unfold`` and left at zero by
    ``fold``; confirm this is acceptable for the image sizes in use.
    """
    batch, as_image = _handle_image_4shuffle(x)
    height, width = batch.shape[-2:]
    patch = (height // n_ratio, width // n_ratio)

    # One column of `patches` per image patch; stride == kernel, so no overlap.
    patches = nnf.unfold(batch, kernel_size=patch, stride=patch, padding=0)
    # Draw a separate random patch order for every sample in the batch.
    shuffled = torch.stack(
        [sample[:, torch.randperm(sample.shape[-1])] for sample in patches], dim=0
    )
    # Reassemble the patches into an image of the original spatial size.
    restored = nnf.fold(
        shuffled, batch.shape[-2:], kernel_size=patch, stride=patch, padding=0
    )

    result = restored.squeeze()
    if not as_image:
        return result
    return T.ToPILImage()(result.type(torch.uint8))