Spaces:
Runtime error
Runtime error
# flake8: noqa | |
import copy | |
import json | |
import random | |
from pathlib import Path | |
from pprint import pprint | |
from tqdm import tqdm | |
from transformers import AutoTokenizer | |
def init_random_input(len_range: int = 5, value_gen=5) -> list:
    """Create a random list of integers to seed a DSL program.

    Args:
        len_range: Length control; the list holds between 2 and
            ``len_range + 1`` values (both inclusive).
        value_gen: Magnitude bound; every value is picked from the integers
            ``-value_gen`` .. ``value_gen`` (both inclusive).

    Returns:
        The generated list of integers.
    """
    size = random.randint(2, len_range + 1)
    candidates = list(range(-value_gen, value_gen + 1))
    return [random.choice(candidates) for _ in range(size)]
# Non-zero integer constants in -5..5; the arithmetic template builders pick
# their scalar argument from this list.
const_integer = [value for value in range(-5, 6) if value != 0]
# Functions in the DSL | |
# Each function defines a transformation in the given DSL Grammar. | |
def take(input_list: list, n: int) -> list:
    """Return the leading ``n`` elements of ``input_list``.

    Plain slice semantics apply: an ``n`` beyond the end returns the whole
    list and a negative ``n`` leaves off elements from the end.
    """
    prefix = input_list[:n]
    return prefix
def drop(input_list: list, n: int) -> list:
    """Return ``input_list`` without its leading ``n`` elements.

    Plain slice semantics apply: an ``n`` beyond the end yields an empty list
    and a negative ``n`` keeps only the last ``-n`` elements.
    """
    remainder = input_list[n:]
    return remainder
def minimum(input_list: list) -> int:
    """Return the smallest element of ``input_list``.

    Raises:
        ValueError: If ``input_list`` is empty.
    """
    smallest = min(input_list)
    return smallest
def maximum(input_list: list) -> int:
    """Return the largest element of ``input_list``.

    Raises:
        ValueError: If ``input_list`` is empty.
    """
    largest = max(input_list)
    return largest
def reverse(input_list: list) -> list:
    """Return a new list holding the elements of ``input_list`` back to front."""
    backwards = input_list[::-1]
    return backwards
def sort_asc(input_list: list) -> list:
    """Return a new list with the elements of ``input_list`` in ascending order."""
    ordered = list(input_list)
    ordered.sort()
    return ordered
def sort_des(input_list: list) -> list:
    """Return a new list with the elements of ``input_list`` in descending order."""
    ordered = list(input_list)
    ordered.sort(reverse=True)
    return ordered
def add_n(input_list: list, n: int) -> list:
    """Return a new list in which ``n`` has been added to every element."""
    shifted = []
    for element in input_list:
        shifted.append(element + n)
    return shifted
def sub_n(input_list: list, n: int) -> list:
    """Return a new list in which ``n`` has been subtracted from every element."""
    shifted = []
    for element in input_list:
        shifted.append(element - n)
    return shifted
def mul_n(input_list: list, n: int) -> list:
    """Return a new list in which every element has been multiplied by ``n``."""
    scaled = []
    for element in input_list:
        scaled.append(element * n)
    return scaled
def div_n(input_list: list, n: int) -> list:
    """Return a new list in which every element has been divided by ``n``.

    True division is used, so the results are floats.

    Raises:
        ZeroDivisionError: If ``n`` is 0 and ``input_list`` is not empty.
    """
    scaled = []
    for element in input_list:
        scaled.append(element / n)
    return scaled
def expand_copy(input_list: list) -> list:
    """Return a new list made of ``input_list`` followed by a second copy of itself."""
    doubled = input_list * 2
    return doubled
# Main production rules for the toy DSL: function name -> implementation.
# Every key equals the implementing function's own name, so the table is
# derived from the functions themselves; insertion order follows the tuple.
# NOTE(review): minimum, maximum and div_n are defined above but are not
# registered here - confirm whether that omission is intentional.
list_manip_dsl = {
    production.__name__: production
    for production in (
        take,
        drop,
        reverse,
        sort_asc,
        sort_des,
        add_n,
        sub_n,
        mul_n,
        expand_copy,
    )
}
# Use this class to execute programs written in the DSL. | |
class Interpreter: | |
def __init__(self) -> None: | |
self.parser = list_manip_dsl | |
def __call__(self, statement_string: str): | |
""" | |
Evaluation Function for the interpreter. | |
args: | |
statement_string (str) : Statement String | |
""" | |
try: | |
return eval(statement_string) # Adding an exception to unparsable strings | |
except: | |
return "ERROR" | |
# Shared module-level interpreter; the gen_* template builders call it to
# compute the output of each generated program.
interpreter = Interpreter()
# TEMPLATE
# Record that stores the input, the output and the function text of a sample.
# input : Arguments given to the function (the list expression comes first).
# function_template : The atomic function call in the given DSL grammar.
# output : Transformed output obtained by applying the function to the input.
generation_template = dict(function_template="NONE", output="NONE", input=[])
# Each gen_* function builds the template for one DSL function and is
# called when that function is chosen while sampling the dataset.
# Each of them takes expressions allowed by the grammar and returns a template.
# Example: gen_take() generates a template for the take function.
# The take function has two arguments:
# a list expression and a bounded integer (which should not be
# more than the length of the list).
def gen_take(expr1=None, expr2=None):
    """Build a generation template for a ``take`` call.

    Args:
        expr1: List literal or nested DSL expression (program text) to take
            from. A random list from ``init_random_input()`` is used when
            omitted.
        expr2: Number of elements to take. When omitted it is drawn from
            ``1 .. len - 2`` of the list ``expr1`` stands for, falling back
            to ``1`` when that range would be empty.

    Returns:
        A copy of ``generation_template`` whose ``function_template`` is the
        program text, ``output`` the interpreter result (``"ERROR"`` when the
        evaluation fails) and ``input`` the list ``[expr1, expr2]``.
    """
    if expr1 is None:
        expr1 = init_random_input()
    if expr2 is None:
        # A nested expression arrives as program text; bound the count by the
        # list it evaluates to, not by the number of characters in the text.
        operand = interpreter(expr1) if isinstance(expr1, str) else expr1
        size = len(operand) if isinstance(operand, list) else 0
        # random.choice raises IndexError on an empty range, which used to
        # happen for 2-element lists; never let the upper bound drop below 2.
        expr2 = random.choice(range(1, max(2, size - 1)))
    formatted_fn = f"take({expr1},{expr2})"
    template = copy.copy(generation_template)
    template["function_template"] = formatted_fn
    template["output"] = interpreter(formatted_fn)
    template["input"] = [expr1, expr2]
    return template
def gen_drop(expr1=None, expr2=None):
    """Build a generation template for a ``drop`` call.

    Args:
        expr1: List literal or nested DSL expression (program text) to drop
            from. A random list from ``init_random_input()`` is used when
            omitted.
        expr2: Number of leading elements to drop. When omitted it is drawn
            from ``1 .. len - 2`` of the list ``expr1`` stands for, falling
            back to ``1`` when that range would be empty.

    Returns:
        A copy of ``generation_template`` whose ``function_template`` is the
        program text, ``output`` the interpreter result (``"ERROR"`` when the
        evaluation fails) and ``input`` the list ``[expr1, expr2]``.
    """
    if expr1 is None:
        expr1 = init_random_input()
    if expr2 is None:
        # A nested expression arrives as program text; bound the count by the
        # list it evaluates to, not by the number of characters in the text.
        operand = interpreter(expr1) if isinstance(expr1, str) else expr1
        size = len(operand) if isinstance(operand, list) else 0
        # random.choice raises IndexError on an empty range, which used to
        # happen for 2-element lists; never let the upper bound drop below 2.
        expr2 = random.choice(range(1, max(2, size - 1)))
    formatted_fn = f"drop({expr1},{expr2})"
    template = copy.copy(generation_template)
    template["function_template"] = formatted_fn
    template["output"] = interpreter(formatted_fn)
    template["input"] = [expr1, expr2]
    return template
def gen_minimum(expr1=None):
    """Build a generation template for a ``minimum`` call.

    Args:
        expr1: List literal or nested DSL expression (program text) to wrap.
            A random list from ``init_random_input()`` is used when omitted.

    Returns:
        A copy of ``generation_template`` whose ``function_template`` is the
        program text, ``output`` the interpreter result (``"ERROR"`` when the
        evaluation fails) and ``input`` the list ``[expr1]``.
    """
    # Identity test: ``== None`` would dispatch to the argument's ``__eq__``.
    if expr1 is None:
        expr1 = init_random_input()
    formatted_fn = f"minimum({expr1})"
    template = copy.copy(generation_template)
    template.update(
        function_template=formatted_fn,
        output=interpreter(formatted_fn),
        input=[expr1],
    )
    return template
def gen_maximum(expr1=None):
    """Build a generation template for a ``maximum`` call.

    Args:
        expr1: List literal or nested DSL expression (program text) to wrap.
            A random list from ``init_random_input()`` is used when omitted.

    Returns:
        A copy of ``generation_template`` whose ``function_template`` is the
        program text, ``output`` the interpreter result (``"ERROR"`` when the
        evaluation fails) and ``input`` the list ``[expr1]``.
    """
    # Identity test: ``== None`` would dispatch to the argument's ``__eq__``.
    if expr1 is None:
        expr1 = init_random_input()
    formatted_fn = f"maximum({expr1})"
    template = copy.copy(generation_template)
    template.update(
        function_template=formatted_fn,
        output=interpreter(formatted_fn),
        input=[expr1],
    )
    return template
def gen_reverse(expr1=None):
    """Build a generation template for a ``reverse`` call.

    Args:
        expr1: List literal or nested DSL expression (program text) to wrap.
            A random list from ``init_random_input()`` is used when omitted.

    Returns:
        A copy of ``generation_template`` whose ``function_template`` is the
        program text, ``output`` the interpreter result (``"ERROR"`` when the
        evaluation fails) and ``input`` the list ``[expr1]``.
    """
    # Identity test: ``== None`` would dispatch to the argument's ``__eq__``.
    if expr1 is None:
        expr1 = init_random_input()
    formatted_fn = f"reverse({expr1})"
    template = copy.copy(generation_template)
    template.update(
        function_template=formatted_fn,
        output=interpreter(formatted_fn),
        input=[expr1],
    )
    return template
def gen_sort_asc(expr1=None):
    """Build a generation template for a ``sort_asc`` call.

    Args:
        expr1: List literal or nested DSL expression (program text) to wrap.
            A random list from ``init_random_input()`` is used when omitted.

    Returns:
        A copy of ``generation_template`` whose ``function_template`` is the
        program text, ``output`` the interpreter result (``"ERROR"`` when the
        evaluation fails) and ``input`` the list ``[expr1]``.
    """
    # Identity test: ``== None`` would dispatch to the argument's ``__eq__``.
    if expr1 is None:
        expr1 = init_random_input()
    formatted_fn = f"sort_asc({expr1})"
    template = copy.copy(generation_template)
    template.update(
        function_template=formatted_fn,
        output=interpreter(formatted_fn),
        input=[expr1],
    )
    return template
def gen_sort_des(expr1=None):
    """Build a generation template for a ``sort_des`` call.

    Args:
        expr1: List literal or nested DSL expression (program text) to wrap.
            A random list from ``init_random_input()`` is used when omitted.

    Returns:
        A copy of ``generation_template`` whose ``function_template`` is the
        program text, ``output`` the interpreter result (``"ERROR"`` when the
        evaluation fails) and ``input`` the list ``[expr1]``.
    """
    # Identity test: ``== None`` would dispatch to the argument's ``__eq__``.
    if expr1 is None:
        expr1 = init_random_input()
    formatted_fn = f"sort_des({expr1})"
    template = copy.copy(generation_template)
    template.update(
        function_template=formatted_fn,
        output=interpreter(formatted_fn),
        input=[expr1],
    )
    return template
def gen_add_n(expr1=None, expr2=None):
    """Build a generation template for an ``add_n`` call.

    Args:
        expr1: List literal or nested DSL expression (program text) to wrap.
            A random list from ``init_random_input()`` is used when omitted.
        expr2: Integer to add. A random entry of ``const_integer`` is used
            when omitted.

    Returns:
        A copy of ``generation_template`` whose ``function_template`` is the
        program text, ``output`` the interpreter result (``"ERROR"`` when the
        evaluation fails) and ``input`` the list ``[expr1, expr2]``.
    """
    # Identity tests: ``== None`` would dispatch to the argument's ``__eq__``.
    if expr1 is None:
        expr1 = init_random_input()
    if expr2 is None:
        expr2 = random.choice(const_integer)
    formatted_fn = f"add_n({expr1},{expr2})"
    template = copy.copy(generation_template)
    template.update(
        function_template=formatted_fn,
        output=interpreter(formatted_fn),
        input=[expr1, expr2],
    )
    return template
def gen_sub_n(expr1=None, expr2=None):
    """Build a generation template for a ``sub_n`` call.

    Args:
        expr1: List literal or nested DSL expression (program text) to wrap.
            A random list from ``init_random_input()`` is used when omitted.
        expr2: Integer to subtract. A random entry of ``const_integer`` is
            used when omitted.

    Returns:
        A copy of ``generation_template`` whose ``function_template`` is the
        program text, ``output`` the interpreter result (``"ERROR"`` when the
        evaluation fails) and ``input`` the list ``[expr1, expr2]``.
    """
    # Identity tests: ``== None`` would dispatch to the argument's ``__eq__``.
    if expr1 is None:
        expr1 = init_random_input()
    if expr2 is None:
        expr2 = random.choice(const_integer)
    formatted_fn = f"sub_n({expr1},{expr2})"
    template = copy.copy(generation_template)
    template.update(
        function_template=formatted_fn,
        output=interpreter(formatted_fn),
        input=[expr1, expr2],
    )
    return template
def gen_mul_n(expr1=None, expr2=None):
    """Build a generation template for a ``mul_n`` call.

    Args:
        expr1: List literal or nested DSL expression (program text) to wrap.
            A random list from ``init_random_input()`` is used when omitted.
        expr2: Integer to multiply by. A random entry of ``const_integer`` is
            used when omitted.

    Returns:
        A copy of ``generation_template`` whose ``function_template`` is the
        program text, ``output`` the interpreter result (``"ERROR"`` when the
        evaluation fails) and ``input`` the list ``[expr1, expr2]``.
    """
    # Identity tests: ``== None`` would dispatch to the argument's ``__eq__``.
    if expr1 is None:
        expr1 = init_random_input()
    if expr2 is None:
        expr2 = random.choice(const_integer)
    formatted_fn = f"mul_n({expr1},{expr2})"
    template = copy.copy(generation_template)
    template.update(
        function_template=formatted_fn,
        output=interpreter(formatted_fn),
        input=[expr1, expr2],
    )
    return template
def gen_div_n(expr1=None, expr2=None):
    """Build a generation template for a ``div_n`` call.

    Args:
        expr1: List literal or nested DSL expression (program text) to wrap.
            A random list from ``init_random_input()`` is used when omitted.
        expr2: Integer divisor. A random entry of ``const_integer`` (which
            holds no zero) is used when omitted.

    Returns:
        A copy of ``generation_template`` whose ``function_template`` is the
        program text, ``output`` the interpreter result (``"ERROR"`` when the
        evaluation fails, e.g. for a zero divisor) and ``input`` the list
        ``[expr1, expr2]``.
    """
    # Identity tests: ``== None`` would dispatch to the argument's ``__eq__``.
    if expr1 is None:
        expr1 = init_random_input()
    if expr2 is None:
        expr2 = random.choice(const_integer)
    formatted_fn = f"div_n({expr1},{expr2})"
    template = copy.copy(generation_template)
    template.update(
        function_template=formatted_fn,
        output=interpreter(formatted_fn),
        input=[expr1, expr2],
    )
    return template
def gen_expand_copy(expr1=None, expr2=None):
    """Build a generation template for an ``expand_copy`` call.

    ``expand_copy`` takes a single list argument. The previous two-argument
    call text ``expand_copy(list,n)`` therefore always failed to evaluate and
    every template produced here carried the output ``"ERROR"``.

    Args:
        expr1: List literal or nested DSL expression (program text) to wrap.
            A random list from ``init_random_input()`` is used when omitted.
        expr2: Ignored. Kept only so existing callers that pass a second
            argument keep working.

    Returns:
        A copy of ``generation_template`` whose ``function_template`` is the
        program text, ``output`` the interpreter result (``"ERROR"`` when the
        evaluation fails) and ``input`` the list ``[expr1]``.
    """
    if expr1 is None:
        expr1 = init_random_input()
    # Emit a unary call so the program text matches expand_copy's signature.
    formatted_fn = f"expand_copy({expr1})"
    template = copy.copy(generation_template)
    template["function_template"] = formatted_fn
    template["output"] = interpreter(formatted_fn)
    template["input"] = [expr1]
    return template
# Template builders, keyed by the name of the DSL function they wrap.
list_manip_dsl_gen = dict(
    take=gen_take,
    drop=gen_drop,
    minimum=gen_minimum,
    maximum=gen_maximum,
    reverse=gen_reverse,
    sort_asc=gen_sort_asc,
    sort_des=gen_sort_des,
    add_n=gen_add_n,
    sub_n=gen_sub_n,
    mul_n=gen_mul_n,
    div_n=gen_div_n,
    expand_copy=gen_expand_copy,
)
class Sampler:
    """Sample random chains of nested DSL function calls."""

    def __init__(
        self,
        max_sample_length: int = 5,
        code_sep: str = ";",
        interpreter_sep: str = "->",
    ):
        """Store the sampling configuration and the DSL lookup tables.

        Args:
            max_sample_length: Chain length used by ``sample_production`` when
                it is called with ``gen_length=None``.
            code_sep: Stored as ``self.code_sep``; not read inside this class.
            interpreter_sep: Stored as ``self.interpreter_sep``; not read
                inside this class.
        """
        self.max_sample_length = max_sample_length
        self.parser = Interpreter()
        self.production_list = list_manip_dsl
        # Names that may be drawn while sampling.
        self.production_idt = list(self.production_list)
        self.production_gen_list = list_manip_dsl_gen
        self.code_sep = code_sep
        self.interpreter_sep = interpreter_sep

    def sample_production(self, gen_length: int = 5):
        """Sample a chain of at most ``gen_length`` nested function templates.

        The first template is generated from a fresh random input. Every later
        template wraps the program text of the previous one. Sampling stops
        early as soon as a wrapped program evaluates to ``"ERROR"``; that
        failing template is not included.

        Args:
            gen_length: Maximum number of templates. Pass ``None`` to use
                ``self.max_sample_length`` (the default of 5 does not consult
                it).

        Returns:
            The list of generated templates, innermost program first.
        """
        if gen_length is None:
            gen_length = self.max_sample_length
        hash_functions = []
        for _ in range(gen_length):
            chosen_name = random.choice(self.production_idt)
            generator = self.production_gen_list[chosen_name]
            if not hash_functions:
                # First step: no previous program to wrap. Its output is not
                # checked here, matching the established behaviour.
                hash_functions.append(generator())
                continue
            generated_function = generator(hash_functions[-1]["function_template"])
            if generated_function["output"] == "ERROR":
                break
            hash_functions.append(generated_function)
        return hash_functions
def create_synthetic_dataset(size: int, io_size=3) -> list:
    """Generate up to ``size`` synthetic (prompt, program) examples.

    Each attempt samples one chain of DSL calls. Attempts that fail while
    sampling, or whose final output is an empty list or ``"ERROR"``, are
    skipped, so the result can hold fewer than ``size`` entries.

    Args:
        size: Number of sampling attempts.
        io_size: Not used by this function; kept for existing callers.

    Returns:
        A list of dicts with the keys ``"input"`` (prompt text), ``"output"``
        (program text), ``"io_inp"`` (initial input) and ``"io_out"`` (final
        output).
    """
    output_list = []
    sampler = Sampler()
    for _ in tqdm(range(size)):
        try:
            sampled = sampler.sample_production()
            inp = sampled[0]["input"][0]
            out = sampled[-1]["output"]
            function = sampled[-1]["function_template"]
        except Exception:
            # Best-effort generation: a failed sample is simply skipped.
            # ``Exception`` (instead of a bare ``except``) keeps Ctrl-C and
            # SystemExit working during long runs.
            continue
        if out == [] or out == "ERROR":
            continue
        output_list.append(
            {
                "input": f"Input: {inp} Output: {out} Function:",
                "output": function,
                "io_inp": inp,
                "io_out": out,
            }
        )
    return output_list
def write_to_json(data: dict, file_name: str):
    """Serialize ``data`` as JSON with a 2-space indent into ``file_name``.

    Args:
        data: JSON-serializable object to store.
        file_name: Destination path; an existing file is overwritten.
    """
    with open(file_name, "w") as handle:
        handle.write(json.dumps(data, indent=2))
def basic_stats(dataset, tokenizer):
    """Compute token-length statistics over the examples of ``dataset``.

    Each example is rendered as its ``"input"`` and ``"output"`` fields
    joined by a space and followed by ``<|endoftext|>``; the length is the
    number of ``input_ids`` the tokenizer returns for that text.

    Args:
        dataset: Iterable of dicts with string ``"input"`` and ``"output"``.
        tokenizer: Callable mapping a string to a mapping with ``"input_ids"``.

    Returns:
        A dict with the ``"max"``, ``"min"`` and ``"mean"`` token counts.
    """
    token_counts = [
        len(tokenizer(example["input"] + " " + example["output"] + "<|endoftext|>")["input_ids"])
        for example in tqdm(dataset)
    ]
    return {
        "max": max(token_counts),
        "min": min(token_counts),
        "mean": sum(token_counts) / len(token_counts),
    }
if __name__ == "__main__":
    # Debugging aids kept from the original script:
    # sampler = Sampler()
    # pprint(sampler.sample_production())
    # pprint(interpreter("div_n(reverse([-2, -5, -4]),1)"))

    # Generate both splits first; nothing is written until both exist.
    train_data = create_synthetic_dataset(2_000_000)
    test_data = create_synthetic_dataset(2_000)
    print(f"Train data size: {len(train_data)}")
    print(f"Test data size: {len(test_data)}")

    # Make sure the output directory exists before writing the two files.
    output_dir = Path("dataset")
    output_dir.mkdir(parents=True, exist_ok=True)
    write_to_json(train_data, "dataset/train.json")
    write_to_json(test_data, "dataset/test.json")