# flake8: noqa
"""Toy list-manipulation DSL, its interpreter, and a synthetic dataset sampler.

The module defines a handful of atomic list transformations, an ``eval``-based
interpreter for programs written with them, template generators that build
(random) programs, and helpers that sample nested programs into an
input/output/function dataset written as JSON.
"""
import copy
import json
import random
from pathlib import Path
from pprint import pprint

try:
    from tqdm import tqdm
except ImportError:
    # The progress bar is cosmetic; degrade to plain iteration without it.
    def tqdm(iterable, *args, **kwargs):
        return iterable


try:
    from transformers import AutoTokenizer
except ImportError:
    # Nothing in this module calls AutoTokenizer directly; callers of
    # ``basic_stats`` supply their own tokenizer object.
    AutoTokenizer = None


def init_random_input(len_range: int = 5, value_gen=5) -> list:
    """Return a random integer list.

    The length is drawn uniformly from ``[2, len_range + 1]`` and every value
    uniformly from ``[-value_gen, value_gen]``.
    """
    len_gen = random.randint(2, len_range + 1)
    value_range = list(range(-value_gen, value_gen + 1))
    return [random.choice(value_range) for _ in range(len_gen)]


# Non-zero integer constants used as the scalar argument of the arithmetic
# functions (zero is excluded, so div_n never divides by zero).
const_integer = [-5, -4, -3, -2, -1, 1, 2, 3, 4, 5]


# Functions in the DSL
# Each function defines a transformation in the given DSL Grammar.
def take(input_list: list, n: int) -> list:
    """Return the first ``n`` elements of ``input_list``."""
    return input_list[:n]


def drop(input_list: list, n: int) -> list:
    """Return ``input_list`` without its first ``n`` elements."""
    return input_list[n:]


def minimum(input_list: list) -> int:
    """Return the smallest element of ``input_list``."""
    return min(input_list)


def maximum(input_list: list) -> int:
    """Return the largest element of ``input_list``."""
    return max(input_list)


def reverse(input_list: list) -> list:
    """Return a reversed copy of ``input_list``."""
    return input_list[::-1]


def sort_asc(input_list: list) -> list:
    """Return ``input_list`` sorted in ascending order."""
    return sorted(input_list)


def sort_des(input_list: list) -> list:
    """Return ``input_list`` sorted in descending order."""
    return sorted(input_list, reverse=True)


def add_n(input_list: list, n: int) -> list:
    """Add ``n`` to every element."""
    return [x + n for x in input_list]


def sub_n(input_list: list, n: int) -> list:
    """Subtract ``n`` from every element."""
    return [x - n for x in input_list]


def mul_n(input_list: list, n: int) -> list:
    """Multiply every element by ``n``."""
    return [x * n for x in input_list]


def div_n(input_list: list, n: int) -> list:
    """Divide every element by ``n`` (true division, so results are floats)."""
    return [x / n for x in input_list]


def expand_copy(input_list: list, n: int = 1) -> list:
    """Return ``input_list`` followed by ``n`` additional copies of itself.

    With the default ``n=1`` this is ``input_list + input_list``, the original
    one-argument behaviour. The optional ``n`` exists because
    ``gen_expand_copy`` emits two-argument calls, which previously always
    failed with a ``TypeError`` inside the interpreter.
    NOTE(review): "n extra copies" is the chosen meaning of the second
    argument; confirm it matches the intended grammar.
    """
    return input_list * (n + 1)


# Main Production Rules for the Toy DSL.
list_manip_dsl = {
    "take": take,
    "drop": drop,
    "reverse": reverse,
    "sort_asc": sort_asc,
    "sort_des": sort_des,
    "add_n": add_n,
    "sub_n": sub_n,
    "mul_n": mul_n,
    "expand_copy": expand_copy,
}


# Use this class to execute programs written in the DSL.
class Interpreter:
    """Evaluate DSL program strings with Python's ``eval``."""

    def __init__(self) -> None:
        self.parser = list_manip_dsl

    def __call__(self, statement_string: str):
        """
        Evaluation Function for the interpreter.

        args:
            statement_string (str) : Statement String

        Returns the evaluated value, or the string ``"ERROR"`` when the
        statement cannot be parsed or raises while running.
        """
        try:
            # SECURITY: eval executes arbitrary Python. This is only acceptable
            # because the strings are produced by the generators in this
            # module; never pass untrusted input here.
            return eval(statement_string)
        # Unparsable or failing programs are reported as "ERROR".
        # ``Exception`` (not a bare except) so Ctrl-C still interrupts.
        except Exception:
            return "ERROR"


interpreter = Interpreter()

# TEMPLATE
# This is used to store the input, output and the function template.
# Input : List given as an input to the function.
# function_template : The atomic function in a given DSL Grammar
# Output : Transformed output by applying function on the input.
generation_template = {"function_template": "NONE", "output": "NONE", "input": []}


def _build_template(fn_name: str, *args) -> dict:
    """Format ``fn_name(arg1,arg2,...)``, evaluate it and fill a template."""
    formatted_fn = f"{fn_name}({','.join(str(arg) for arg in args)})"
    template = copy.copy(generation_template)
    template["function_template"] = formatted_fn
    template["output"] = interpreter(formatted_fn)
    template["input"] = list(args)
    return template


def _sample_slice_bound(expr) -> int:
    """Pick a random slice bound for ``take``/``drop`` applied to ``expr``.

    ``expr`` is either a list or a nested program string; a string is
    evaluated first so the bound is based on the list length, not on the
    number of characters in the program text. The bound is drawn from
    ``1 .. len - 2`` and clamped so that short (or non-list) values yield 1
    instead of raising on an empty range.
    """
    value = interpreter(expr) if isinstance(expr, str) else expr
    length = len(value) if isinstance(value, list) else 0
    return random.choice(range(1, max(length - 1, 2)))


# Each of the generate function is used to generate a
# template for a given function
# if chosen while sampling the dataset.
# each function takes in expressions based on the grammar and generates a template.
# Example: gen_take() generates a template for the take function.
# take function has two arguments,
# list_expression and a bounded integer (should not be more
# than the length of the list).
def gen_take(expr1=None, expr2=None) -> dict:
    """Build a template for ``take``; missing arguments are sampled."""
    if expr1 is None:
        expr1 = init_random_input()
    if expr2 is None:
        expr2 = _sample_slice_bound(expr1)
    return _build_template("take", expr1, expr2)


def gen_drop(expr1=None, expr2=None) -> dict:
    """Build a template for ``drop``; missing arguments are sampled."""
    if expr1 is None:
        expr1 = init_random_input()
    if expr2 is None:
        expr2 = _sample_slice_bound(expr1)
    return _build_template("drop", expr1, expr2)


def gen_minimum(expr1=None) -> dict:
    """Build a template for ``minimum``; a missing list is sampled."""
    if expr1 is None:
        expr1 = init_random_input()
    return _build_template("minimum", expr1)


def gen_maximum(expr1=None) -> dict:
    """Build a template for ``maximum``; a missing list is sampled."""
    if expr1 is None:
        expr1 = init_random_input()
    return _build_template("maximum", expr1)


def gen_reverse(expr1=None) -> dict:
    """Build a template for ``reverse``; a missing list is sampled."""
    if expr1 is None:
        expr1 = init_random_input()
    return _build_template("reverse", expr1)


def gen_sort_asc(expr1=None) -> dict:
    """Build a template for ``sort_asc``; a missing list is sampled."""
    if expr1 is None:
        expr1 = init_random_input()
    return _build_template("sort_asc", expr1)


def gen_sort_des(expr1=None) -> dict:
    """Build a template for ``sort_des``; a missing list is sampled."""
    if expr1 is None:
        expr1 = init_random_input()
    return _build_template("sort_des", expr1)


def gen_add_n(expr1=None, expr2=None) -> dict:
    """Build a template for ``add_n``; missing arguments are sampled."""
    if expr1 is None:
        expr1 = init_random_input()
    if expr2 is None:
        expr2 = random.choice(const_integer)
    return _build_template("add_n", expr1, expr2)


def gen_sub_n(expr1=None, expr2=None) -> dict:
    """Build a template for ``sub_n``; missing arguments are sampled."""
    if expr1 is None:
        expr1 = init_random_input()
    if expr2 is None:
        expr2 = random.choice(const_integer)
    return _build_template("sub_n", expr1, expr2)


def gen_mul_n(expr1=None, expr2=None) -> dict:
    """Build a template for ``mul_n``; missing arguments are sampled."""
    if expr1 is None:
        expr1 = init_random_input()
    if expr2 is None:
        expr2 = random.choice(const_integer)
    return _build_template("mul_n", expr1, expr2)


def gen_div_n(expr1=None, expr2=None) -> dict:
    """Build a template for ``div_n``; missing arguments are sampled."""
    if expr1 is None:
        expr1 = init_random_input()
    if expr2 is None:
        expr2 = random.choice(const_integer)
    return _build_template("div_n", expr1, expr2)


def gen_expand_copy(expr1=None, expr2=None) -> dict:
    """Build a template for ``expand_copy``; missing arguments are sampled.

    ``expr2`` (1 or 2 when sampled) is passed as the second argument of
    ``expand_copy``.
    """
    if expr1 is None:
        expr1 = init_random_input()
    if expr2 is None:
        expr2 = random.choice(range(1, 3))
    return _build_template("expand_copy", expr1, expr2)


list_manip_dsl_gen = {
    "take": gen_take,
    "drop": gen_drop,
    "minimum": gen_minimum,
    "maximum": gen_maximum,
    "reverse": gen_reverse,
    "sort_asc": gen_sort_asc,
    "sort_des": gen_sort_des,
    "add_n": gen_add_n,
    "sub_n": gen_sub_n,
    "mul_n": gen_mul_n,
    "div_n": gen_div_n,
    "expand_copy": gen_expand_copy,
}


class Sampler:
    """Sample chains of nested DSL programs."""

    def __init__(
        self,
        max_sample_length: int = 5,
        code_sep: str = ";",
        interpreter_sep: str = "->",
    ):
        self.max_sample_length = max_sample_length
        self.parser = Interpreter()
        self.production_list = list_manip_dsl
        # Only the names in list_manip_dsl are sampled, even though
        # list_manip_dsl_gen holds generators for more functions.
        self.production_idt = list(self.production_list)
        self.production_gen_list = list_manip_dsl_gen
        self.code_sep = code_sep
        self.interpreter_sep = interpreter_sep

    def sample_production(self, gen_length: int = 5):
        """Sample up to ``gen_length`` successively nested programs.

        The first program is generated from a random input list; every later
        one wraps the previous program string as its first argument. Sampling
        stops early when a nested program evaluates to ``"ERROR"``. Passing
        ``gen_length=None`` uses ``self.max_sample_length`` instead.

        Returns the list of generated templates, innermost program first.
        """
        if gen_length is None:
            gen_length = self.max_sample_length
        hash_functions = []
        for _ in range(gen_length):
            chosen_name = random.choice(self.production_idt)
            generator = self.production_gen_list[chosen_name]
            if not hash_functions:
                # The first program is kept even if it failed; callers filter
                # on the final output.
                hash_functions.append(generator())
                continue
            generated_function = generator(hash_functions[-1]["function_template"])
            if generated_function["output"] == "ERROR":
                break
            hash_functions.append(generated_function)
        return hash_functions


def create_synthetic_dataset(size: int, io_size=3) -> list:
    """Sample ``size`` programs and return the usable ones as prompt records.

    Each record holds the prompt text (``input``), the target program
    (``output``) and the raw input/output values (``io_inp``/``io_out``).
    Samples whose final output is empty or ``"ERROR"`` are dropped, so the
    result may contain fewer than ``size`` entries. ``io_size`` is accepted
    for compatibility and is not used.
    """
    output_list = []
    sampler = Sampler()
    for _ in tqdm(range(size)):
        try:
            sampled = sampler.sample_production()
            inp = sampled[0]["input"][0]
            out = sampled[-1]["output"]
            function = sampled[-1]["function_template"]
        except Exception:
            # Best-effort generation: skip a sample that failed, but do not
            # swallow KeyboardInterrupt/SystemExit like a bare except would.
            continue
        if out != [] and out != "ERROR":
            output_list.append(
                {
                    "input": f"Input: {inp} Output: {out} Function:",
                    "output": function,
                    "io_inp": inp,
                    "io_out": out,
                }
            )
    return output_list


def write_to_json(data: list, file_name: str):
    """Write ``data`` to ``file_name`` as indented JSON."""
    with open(file_name, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=2)


def basic_stats(dataset, tokenizer):
    """
    Basic stats to calculate the token length of the dataset.

    ``tokenizer`` is called on ``input + " " + output + "<|endoftext|>"`` and
    must return a mapping with an ``"input_ids"`` entry. Returns the max, min
    and mean token counts; all zero for an empty dataset.
    """
    length_list = []
    for examples in tqdm(dataset):
        datapoint = tokenizer(examples["input"] + " " + examples["output"] + "<|endoftext|>")
        length_list.append(len(datapoint["input_ids"]))
    if not length_list:
        # max()/min() raise on an empty sequence and the mean would divide by 0.
        return {"max": 0, "min": 0, "mean": 0.0}
    return {
        "max": max(length_list),
        "min": min(length_list),
        "mean": sum(length_list) / len(length_list),
    }


def main() -> None:
    """Generate the train/test datasets and write them under ``dataset/``."""
    # sampler = Sampler()
    # pprint(sampler.sample_production())
    # pprint(interpreter("div_n(reverse([-2, -5, -4]),1)"))
    train_data = create_synthetic_dataset(2000000)
    test_data = create_synthetic_dataset(2_000)
    print(f"Train data size: {len(train_data)}")
    print(f"Test data size: {len(test_data)}")
    Path("dataset").mkdir(parents=True, exist_ok=True)
    write_to_json(train_data, "dataset/train.json")
    write_to_json(test_data, "dataset/test.json")


if __name__ == "__main__":
    main()