teachyourselfcoding's picture
Upload 245 files
fa6856c
# flake8: noqa
import copy
import json
import random
from pathlib import Path
from pprint import pprint
from tqdm import tqdm
from transformers import AutoTokenizer
def init_random_input(len_range: int = 5, value_gen=5) -> list:
    """Return a random list of integers to seed a DSL program.

    Args:
        len_range: The list length is drawn uniformly from ``2`` to
            ``len_range + 1`` (both inclusive).
        value_gen: Each element is drawn uniformly from the integers in
            ``[-value_gen, value_gen]``.

    Returns:
        The freshly generated list.
    """
    size = random.randint(2, len_range + 1)
    candidates = list(range(-value_gen, value_gen + 1))
    return [random.choice(candidates) for _ in range(size)]
# Non-zero integer constants offered as the second argument of the arithmetic DSL functions.
const_integer = [n for n in range(-5, 6) if n != 0]
# Functions in the DSL
# Each function defines a transformation in the given DSL Grammar.
def take(input_list: list, n: int) -> list:
    """Return the slice of ``input_list`` that stops at index ``n``."""
    head = slice(None, n)
    return input_list[head]
def drop(input_list: list, n: int) -> list:
    """Return the slice of ``input_list`` that starts at index ``n``."""
    tail = slice(n, None)
    return input_list[tail]
def minimum(input_list: list) -> int:
    """Return the smallest element of ``input_list``."""
    smallest = min(input_list)
    return smallest
def maximum(input_list: list) -> int:
    """Return the largest element of ``input_list``."""
    largest = max(input_list)
    return largest
def reverse(input_list: list) -> list:
    """Return a new list holding the elements of ``input_list`` in reverse order."""
    backwards = slice(None, None, -1)
    return input_list[backwards]
def sort_asc(input_list: list) -> list:
    """Return a new list with the elements of ``input_list`` in ascending order."""
    ordered = list(input_list)
    ordered.sort()
    return ordered
def sort_des(input_list: list) -> list:
    """Return a new list with the elements of ``input_list`` in descending order."""
    ordered = list(input_list)
    ordered.sort(reverse=True)
    return ordered
def add_n(input_list: list, n: int) -> list:
    """Return a new list in which ``n`` has been added to every element."""
    shifted = []
    for value in input_list:
        shifted.append(value + n)
    return shifted
def sub_n(input_list: list, n: int) -> list:
    """Return a new list in which ``n`` has been subtracted from every element."""
    shifted = []
    for value in input_list:
        shifted.append(value - n)
    return shifted
def mul_n(input_list: list, n: int) -> list:
    """Return a new list in which every element has been multiplied by ``n``."""
    scaled = []
    for value in input_list:
        scaled.append(value * n)
    return scaled
def div_n(input_list: list, n: int) -> list:
    """Return a new list in which every element has been divided by ``n``.

    True division (``/``) is used, so the results are floats. A zero ``n``
    raises ``ZeroDivisionError`` as soon as the first element is processed.
    """
    scaled = []
    for value in input_list:
        scaled.append(value / n)
    return scaled
def expand_copy(input_list: list) -> list:
    """Return a new list made of ``input_list`` followed by a second copy of itself."""
    doubled = input_list * 2
    return doubled
# Main Production Rules for the Toy DSL.
# Maps each production name to the function implementing it.
# Sampler draws function names from the keys of this table, so the order and
# membership here decide what can be sampled. minimum, maximum and div_n are
# defined above but not registered here.
list_manip_dsl = dict(
    take=take,
    drop=drop,
    reverse=reverse,
    sort_asc=sort_asc,
    sort_des=sort_des,
    add_n=add_n,
    sub_n=sub_n,
    mul_n=mul_n,
    expand_copy=expand_copy,
)
# Use this class to execute programs written in the DSL.
class Interpreter:
    """Evaluates statements written in the list-manipulation DSL."""

    def __init__(self) -> None:
        # Table of DSL productions; kept as an attribute, not consulted by __call__.
        self.parser = list_manip_dsl

    def __call__(self, statement_string: str):
        """
        Evaluation Function for the interpreter.
        args:
            statement_string (str) : Statement String
        returns:
            The value the statement evaluates to, or the string "ERROR"
            when the statement cannot be parsed or raises while running.
        """
        try:
            # SECURITY: eval() executes arbitrary Python with this module's
            # globals. Only feed it statements generated by this module,
            # never untrusted text.
            return eval(statement_string)
        except Exception:
            # Unparsable or failing statements are reported, not raised.
            # KeyboardInterrupt / SystemExit are deliberately not caught so
            # that a long generation run can still be interrupted.
            return "ERROR"
# Module-level interpreter instance; the gen_* functions below call it to
# compute the output of each generated statement.
interpreter = Interpreter()
# TEMPLATE
# This is used to store the input, output and the function template.
# Input : List given as an input to the function.
# function_template : The atomic function in a given DSL Grammar
# Output : Transformed output obtained by applying the function to the input.
generation_template = dict(function_template="NONE", output="NONE", input=[])
# Each gen_* function below builds a template for one DSL function
# when that function is chosen while sampling the dataset.
# Each one takes expressions allowed by the grammar and returns a filled-in template.
# Example: gen_take() generates a template for the take function.
# take has two arguments: a list expression and a bounded integer
# (which should not be more than the length of the list).
def gen_take(expr1=None, expr2=None):
    """Build a generation template for a ``take`` call.

    Args:
        expr1: List literal, or a DSL expression string (Sampler passes the
            previous step's ``function_template``). A random list is
            generated when omitted.
        expr2: Number of leading elements to keep. When omitted it is drawn
            from ``1 .. len(list) - 2``, falling back to ``1`` for lists
            too short for that range.

    Returns:
        A copy of ``generation_template`` holding the formatted call, its
        interpreted output (``"ERROR"`` on failure) and ``[expr1, expr2]``.
    """
    if expr1 is None:
        expr1 = init_random_input()
    if expr2 is None:
        # A nested expression arrives as a string, so its len() is a character
        # count; evaluate it to bound the index by the actual list length.
        operand = interpreter(expr1) if isinstance(expr1, str) else expr1
        size = len(operand) if isinstance(operand, list) else 0
        # max(1, ...) keeps the range non-empty for lists of length <= 2,
        # where the old random.choice over an empty range raised IndexError.
        expr2 = random.randint(1, max(1, size - 2))
    formatted_fn = f"take({expr1},{expr2})"
    template = copy.copy(generation_template)
    template.update(
        function_template=formatted_fn,
        output=interpreter(formatted_fn),
        input=[expr1, expr2],
    )
    return template
def gen_drop(expr1=None, expr2=None):
    """Build a generation template for a ``drop`` call.

    Args:
        expr1: List literal, or a DSL expression string (Sampler passes the
            previous step's ``function_template``). A random list is
            generated when omitted.
        expr2: Number of leading elements to discard. When omitted it is
            drawn from ``1 .. len(list) - 2``, falling back to ``1`` for
            lists too short for that range.

    Returns:
        A copy of ``generation_template`` holding the formatted call, its
        interpreted output (``"ERROR"`` on failure) and ``[expr1, expr2]``.
    """
    if expr1 is None:
        expr1 = init_random_input()
    if expr2 is None:
        # A nested expression arrives as a string, so its len() is a character
        # count; evaluate it to bound the index by the actual list length.
        operand = interpreter(expr1) if isinstance(expr1, str) else expr1
        size = len(operand) if isinstance(operand, list) else 0
        # max(1, ...) keeps the range non-empty for lists of length <= 2,
        # where the old random.choice over an empty range raised IndexError.
        expr2 = random.randint(1, max(1, size - 2))
    formatted_fn = f"drop({expr1},{expr2})"
    template = copy.copy(generation_template)
    template.update(
        function_template=formatted_fn,
        output=interpreter(formatted_fn),
        input=[expr1, expr2],
    )
    return template
def gen_minimum(expr1=None):
    """Build a generation template for a ``minimum`` call.

    Args:
        expr1: List literal or DSL expression string; a random list is
            generated when omitted.

    Returns:
        A copy of ``generation_template`` holding the formatted call, its
        interpreted output (``"ERROR"`` on failure) and ``[expr1]``.
    """
    if expr1 is None:
        expr1 = init_random_input()
    formatted_fn = f"minimum({expr1})"
    template = copy.copy(generation_template)
    template.update(
        function_template=formatted_fn,
        output=interpreter(formatted_fn),
        input=[expr1],
    )
    return template
def gen_maximum(expr1=None):
    """Build a generation template for a ``maximum`` call.

    Args:
        expr1: List literal or DSL expression string; a random list is
            generated when omitted.

    Returns:
        A copy of ``generation_template`` holding the formatted call, its
        interpreted output (``"ERROR"`` on failure) and ``[expr1]``.
    """
    if expr1 is None:
        expr1 = init_random_input()
    formatted_fn = f"maximum({expr1})"
    template = copy.copy(generation_template)
    template.update(
        function_template=formatted_fn,
        output=interpreter(formatted_fn),
        input=[expr1],
    )
    return template
def gen_reverse(expr1=None):
    """Build a generation template for a ``reverse`` call.

    Args:
        expr1: List literal or DSL expression string; a random list is
            generated when omitted.

    Returns:
        A copy of ``generation_template`` holding the formatted call, its
        interpreted output (``"ERROR"`` on failure) and ``[expr1]``.
    """
    if expr1 is None:
        expr1 = init_random_input()
    formatted_fn = f"reverse({expr1})"
    template = copy.copy(generation_template)
    template.update(
        function_template=formatted_fn,
        output=interpreter(formatted_fn),
        input=[expr1],
    )
    return template
def gen_sort_asc(expr1=None):
    """Build a generation template for a ``sort_asc`` call.

    Args:
        expr1: List literal or DSL expression string; a random list is
            generated when omitted.

    Returns:
        A copy of ``generation_template`` holding the formatted call, its
        interpreted output (``"ERROR"`` on failure) and ``[expr1]``.
    """
    if expr1 is None:
        expr1 = init_random_input()
    formatted_fn = f"sort_asc({expr1})"
    template = copy.copy(generation_template)
    template.update(
        function_template=formatted_fn,
        output=interpreter(formatted_fn),
        input=[expr1],
    )
    return template
def gen_sort_des(expr1=None):
    """Build a generation template for a ``sort_des`` call.

    Args:
        expr1: List literal or DSL expression string; a random list is
            generated when omitted.

    Returns:
        A copy of ``generation_template`` holding the formatted call, its
        interpreted output (``"ERROR"`` on failure) and ``[expr1]``.
    """
    if expr1 is None:
        expr1 = init_random_input()
    formatted_fn = f"sort_des({expr1})"
    template = copy.copy(generation_template)
    template.update(
        function_template=formatted_fn,
        output=interpreter(formatted_fn),
        input=[expr1],
    )
    return template
def gen_add_n(expr1=None, expr2=None):
    """Build a generation template for an ``add_n`` call.

    Args:
        expr1: List literal or DSL expression string; a random list is
            generated when omitted.
        expr2: Amount to add; drawn from ``const_integer`` when omitted.

    Returns:
        A copy of ``generation_template`` holding the formatted call, its
        interpreted output (``"ERROR"`` on failure) and ``[expr1, expr2]``.
    """
    if expr1 is None:
        expr1 = init_random_input()
    if expr2 is None:
        expr2 = random.choice(const_integer)
    formatted_fn = f"add_n({expr1},{expr2})"
    template = copy.copy(generation_template)
    template.update(
        function_template=formatted_fn,
        output=interpreter(formatted_fn),
        input=[expr1, expr2],
    )
    return template
def gen_sub_n(expr1=None, expr2=None):
    """Build a generation template for a ``sub_n`` call.

    Args:
        expr1: List literal or DSL expression string; a random list is
            generated when omitted.
        expr2: Amount to subtract; drawn from ``const_integer`` when omitted.

    Returns:
        A copy of ``generation_template`` holding the formatted call, its
        interpreted output (``"ERROR"`` on failure) and ``[expr1, expr2]``.
    """
    if expr1 is None:
        expr1 = init_random_input()
    if expr2 is None:
        expr2 = random.choice(const_integer)
    formatted_fn = f"sub_n({expr1},{expr2})"
    template = copy.copy(generation_template)
    template.update(
        function_template=formatted_fn,
        output=interpreter(formatted_fn),
        input=[expr1, expr2],
    )
    return template
def gen_mul_n(expr1=None, expr2=None):
    """Build a generation template for a ``mul_n`` call.

    Args:
        expr1: List literal or DSL expression string; a random list is
            generated when omitted.
        expr2: Multiplier; drawn from ``const_integer`` when omitted.

    Returns:
        A copy of ``generation_template`` holding the formatted call, its
        interpreted output (``"ERROR"`` on failure) and ``[expr1, expr2]``.
    """
    if expr1 is None:
        expr1 = init_random_input()
    if expr2 is None:
        expr2 = random.choice(const_integer)
    formatted_fn = f"mul_n({expr1},{expr2})"
    template = copy.copy(generation_template)
    template.update(
        function_template=formatted_fn,
        output=interpreter(formatted_fn),
        input=[expr1, expr2],
    )
    return template
def gen_div_n(expr1=None, expr2=None):
    """Build a generation template for a ``div_n`` call.

    Args:
        expr1: List literal or DSL expression string; a random list is
            generated when omitted.
        expr2: Divisor; drawn from ``const_integer`` (which contains no
            zero) when omitted.

    Returns:
        A copy of ``generation_template`` holding the formatted call, its
        interpreted output (``"ERROR"`` on failure) and ``[expr1, expr2]``.
    """
    if expr1 is None:
        expr1 = init_random_input()
    if expr2 is None:
        expr2 = random.choice(const_integer)
    formatted_fn = f"div_n({expr1},{expr2})"
    template = copy.copy(generation_template)
    template.update(
        function_template=formatted_fn,
        output=interpreter(formatted_fn),
        input=[expr1, expr2],
    )
    return template
def gen_expand_copy(expr1=None, expr2=None):
    """Build a generation template for an ``expand_copy`` call.

    ``expand_copy`` takes the list as its only argument. The previous
    template formatted a two-argument call, which always failed in the
    interpreter and produced ``"ERROR"``.

    Args:
        expr1: List literal or DSL expression string; a random list is
            generated when omitted.
        expr2: Accepted for backward compatibility and ignored, because
            ``expand_copy`` has no second parameter.

    Returns:
        A copy of ``generation_template`` holding the formatted call, its
        interpreted output (``"ERROR"`` on failure) and ``[expr1]``.
    """
    if expr1 is None:
        expr1 = init_random_input()
    formatted_fn = f"expand_copy({expr1})"
    template = copy.copy(generation_template)
    template.update(
        function_template=formatted_fn,
        output=interpreter(formatted_fn),
        input=[expr1],
    )
    return template
# Maps each DSL function name to the generator that builds its template.
# Sampler looks generators up here by the function name it sampled.
list_manip_dsl_gen = dict(
    take=gen_take,
    drop=gen_drop,
    minimum=gen_minimum,
    maximum=gen_maximum,
    reverse=gen_reverse,
    sort_asc=gen_sort_asc,
    sort_des=gen_sort_des,
    add_n=gen_add_n,
    sub_n=gen_sub_n,
    mul_n=gen_mul_n,
    div_n=gen_div_n,
    expand_copy=gen_expand_copy,
)
class Sampler:
    """Samples random chains of nested DSL function calls."""

    def __init__(
        self,
        max_sample_length: int = 5,
        code_sep: str = ";",
        interpreter_sep: str = "->",
    ):
        """Store the sampling configuration.

        Args:
            max_sample_length: Chain length used by ``sample_production``
                when no explicit ``gen_length`` is given.
            code_sep: Stored on the instance; not used by this class.
            interpreter_sep: Stored on the instance; not used by this class.
        """
        self.max_sample_length = max_sample_length
        self.parser = Interpreter()
        self.production_list = list_manip_dsl
        # Names that may be drawn at each step of a chain.
        self.production_idt = list(self.production_list)
        self.production_gen_list = list_manip_dsl_gen
        self.code_sep = code_sep
        self.interpreter_sep = interpreter_sep

    def sample_production(self, gen_length: "int | None" = None):
        """Sample a chain of progressively nested DSL calls.

        The first template is generated from a fresh random input. Every
        later template wraps the previous template's ``function_template``
        in another randomly chosen function. Sampling stops early as soon
        as a wrapped call evaluates to ``"ERROR"``.

        Args:
            gen_length: Maximum number of templates in the chain. ``None``
                (the default) means ``self.max_sample_length``; previously
                the default was a hard-coded 5, so the constructor setting
                was never honoured.

        Returns:
            The list of generated templates, innermost call first.
        """
        if gen_length is None:
            gen_length = self.max_sample_length
        hash_functions = []
        for _ in range(gen_length):
            chosen_name = random.choice(self.production_idt)
            generator = self.production_gen_list[chosen_name]
            if not hash_functions:
                # First step: let the generator create its own random input.
                hash_functions.append(generator())
                continue
            generated_function = generator(hash_functions[-1]["function_template"])
            if generated_function["output"] == "ERROR":
                break
            hash_functions.append(generated_function)
        return hash_functions
def create_synthetic_dataset(size: int, io_size=3) -> list:
    """Generate prompt/program pairs by sampling random DSL programs.

    Args:
        size: Number of sampling attempts. The returned list can be shorter
            because failed attempts and attempts whose output is empty or
            ``"ERROR"`` are skipped.
        io_size: Unused; kept so existing callers keep working.

    Returns:
        A list of dicts with the keys ``"input"`` (prompt text),
        ``"output"`` (the program), ``"io_inp"`` and ``"io_out"``.
    """
    output_list = []
    sampler = Sampler()
    for _ in tqdm(range(size)):
        try:
            sampled = sampler.sample_production()
            inp = sampled[0]["input"][0]
            out = sampled[-1]["output"]
            function = sampled[-1]["function_template"]
        except Exception:
            # Best effort: a sample that fails to generate is skipped.
            # Catching Exception rather than everything keeps Ctrl-C and
            # SystemExit working during a long run.
            continue
        if out == [] or out == "ERROR":
            continue
        output_list.append(
            {
                "input": f"Input: {inp} Output: {out} Function:",
                "output": function,
                "io_inp": inp,
                "io_out": out,
            }
        )
    return output_list
def write_to_json(data: "list | dict", file_name: str):
    """Write ``data`` to ``file_name`` as JSON indented by two spaces.

    Args:
        data: JSON-serializable object; this module passes the list
            returned by ``create_synthetic_dataset``.
        file_name: Destination path. An existing file is overwritten.
    """
    # Explicit encoding so the file does not depend on the platform locale.
    with open(file_name, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=2)
def basic_stats(dataset, tokenizer):
    """
    Compute token-length statistics over the dataset.

    Each example's "input" and "output" fields are joined with a space,
    suffixed with "<|endoftext|>" and passed to ``tokenizer``; the length of
    the returned "input_ids" is recorded.

    Returns a dict with the "max", "min" and "mean" of those lengths.
    """
    token_counts = [
        len(tokenizer(example["input"] + " " + example["output"] + "<|endoftext|>")["input_ids"])
        for example in tqdm(dataset)
    ]
    longest = max(token_counts)
    shortest = min(token_counts)
    average = sum(token_counts) / len(token_counts)
    return {"max": longest, "min": shortest, "mean": average}
if __name__ == "__main__":
    # Debugging snippets:
    # sampler = Sampler()
    # pprint(sampler.sample_production())
    # pprint(interpreter("div_n(reverse([-2, -5, -4]),1)"))

    # Generate both splits first, then report how many examples each kept.
    train_data = create_synthetic_dataset(2_000_000)
    test_data = create_synthetic_dataset(2_000)
    print(f"Train data size: {len(train_data)}")
    print(f"Test data size: {len(test_data)}")

    # Make sure the output directory exists before writing the two files.
    output_dir = Path("dataset")
    output_dir.mkdir(parents=True, exist_ok=True)
    write_to_json(train_data, "dataset/train.json")
    write_to_json(test_data, "dataset/test.json")