|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""pp ops.""" |
|
|
|
from big_vision.pp.registry import Registry |
|
import tensorflow as tf |
|
|
|
|
|
@Registry.register('preprocess_ops.sci_qa_choices_shuffle')
def sci_qa_choices_shuffle(
    choice_str_inkey='choices',
    ans_inkey='answer',
    indexed_choices_outkey='indexed_choices',
    indexed_answer_outkey='indexed_answer',
):
  """Randomly shuffles the sci_qa choices on the fly.

  Args:
    choice_str_inkey: key of the original choice list from sciqa, e.g.
      ['apple', 'banana', ...].
    ans_inkey: key of the original answer (index into the choice list) from
      sciqa, e.g. 1.
    indexed_choices_outkey: key under which the shuffled choices are stored.
      Each choice is prefixed with a letter index and all of them are joined
      into a single string, e.g. "(A) banana, (B) apple".
    indexed_answer_outkey: key under which the letter of the answer after
      shuffling is stored, e.g. for the shuffle above the original answer 1
      ('banana') ends up at position 0, which is stored as 'A'.

  Returns:
    A function that takes an example dict, writes the two output keys into it
    and returns the same dict.
  """

  def _template(data):
    alphabet = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
    option_labels = tf.constant([f'({a})' for a in alphabet])
    answer_labels = tf.constant(list(alphabet))

    choices = data[choice_str_inkey]
    # Use tf.shape rather than Python's len() so that the op does not depend on
    # len() being supported for the tensor (e.g. a symbolic tensor whose
    # length is only known at run time).
    num_choices = tf.shape(choices)[0]

    shuffled_indices = tf.random.shuffle(tf.range(num_choices))
    shuffled_choices = tf.gather(choices, shuffled_indices)

    # "(A) <choice>, (B) <choice>, ..." using only as many labels as choices.
    data[indexed_choices_outkey] = tf.strings.reduce_join(
        tf.strings.join(
            [option_labels[:num_choices], shuffled_choices], separator=' '
        ),
        separator=', ',
    )

    # tf.equal needs both operands to have the same dtype; cast the answer so
    # that e.g. an int64 answer can be compared with the int32 indices.
    answer = tf.cast(data[ans_inkey], shuffled_indices.dtype)
    # Position at which the original answer index landed after shuffling.
    answer_position = tf.where(tf.equal(shuffled_indices, answer))
    # tf.where yields a [num_matches, 1] tensor; reduce_join flattens the
    # gathered letter(s) into a scalar string.
    data[indexed_answer_outkey] = tf.strings.reduce_join(
        tf.gather(answer_labels, answer_position)
    )
    return data

  return _template
|
|