"""PPO fine-tuning script that rewards generated DSL programs via an interpreter.

Prompts are read from ``dataset/train.json``; each generated sample is scored
by running its ``Function:`` part through ``lang.Interpreter`` and comparing
the result with the ``Output:`` part of the same sample.
"""

import ast
import itertools
import json
import logging
import pathlib

import yaml

from lang import Interpreter

import trlx
from trlx.data.configs import TRLConfig

logger = logging.getLogger(__name__)


class DSLDataset:
    """In-memory view of the train/test JSON files under ``dataset/``."""

    def __init__(self):
        """Load ``dataset/train.json`` and ``dataset/test.json`` eagerly."""
        with open("dataset/train.json", "r") as f:
            self.train_data = json.load(f)
        with open("dataset/test.json", "r") as f:
            self.test_data = json.load(f)
        logger.info("Sucessfully loaded the dataset")

    def load_datapoints(self, split="train"):
        """Yield the ``"input"`` field of every datapoint in ``split``.

        Args:
            split: ``"train"`` or ``"test"``. Any other value yields nothing.

        Yields:
            The ``"input"`` value of each datapoint. For the train split,
            datapoints whose input contains ``"ERROR"`` are skipped; the test
            split is yielded unfiltered.
        """
        if split == "train":
            for datapoint in self.train_data:
                if "ERROR" not in datapoint["input"]:
                    yield datapoint["input"]
        elif split == "test":
            for datapoint in self.test_data:
                yield datapoint["input"]


interpreter = Interpreter()


def reward_fn(samples, **kwargs):
    """Score each sample by interpreting its function and checking the output.

    Every sample is expected to contain an ``Output:`` marker followed by a
    Python literal and a ``Function:`` marker followed by the program text,
    e.g. ``"Input: 1 Output: [-4,-5,-2] Function: div_n(...)"``.

    Args:
        samples: Iterable of sample strings in the format above.
        **kwargs: Accepted and ignored.

    Returns:
        A list with one reward per sample: ``-1`` when the interpreter
        returns ``"ERROR"``, ``1`` when the interpreted result equals the
        expected output, and ``-0.5`` otherwise.

    Raises:
        IndexError: If a sample lacks the ``Function:`` or ``Output:`` marker.
        ValueError, SyntaxError: If the text after ``Output:`` is not a
            valid Python literal.
    """
    reward_list = []
    for sample in samples:
        code = sample.split("Function:")[1].strip()
        expected_text = (
            sample.split("Output:")[1].strip().split("Function:")[0].strip()
        )
        # SECURITY: this text comes from model samples, so it must never be
        # executed. literal_eval only accepts literals (lists, numbers,
        # strings, ...) where the previous eval() would run arbitrary code.
        output = ast.literal_eval(expected_text)
        interpreted_output = interpreter(code)
        if interpreted_output == "ERROR":
            # If the code is unparsable, we give it a negative reward.
            reward_list.append(-1)
        elif output == interpreted_output:
            # Parseable code with the correct output gets a positive reward.
            reward_list.append(1)
        else:
            # Parseable code with an incorrect output gets a smaller penalty.
            reward_list.append(-0.5)
    return reward_list


config_path = pathlib.Path(__file__).parent.joinpath("configs/trlx_ppo_config.yml")
with config_path.open() as f:
    default_config = yaml.safe_load(f)


def main(hparams=None):
    """Train with trlx on the first 1000 train prompts and save the model.

    Args:
        hparams: Optional mapping of overrides passed to ``TRLConfig.update``
            together with the YAML defaults. ``None`` means no overrides.
    """
    # Build a fresh dict per call instead of sharing one mutable default.
    if hparams is None:
        hparams = {}
    config = TRLConfig.update(default_config, hparams)

    # Dataset
    dataset = DSLDataset()
    # islice stops after 1000 prompts rather than materializing all of them
    # and slicing afterwards; the resulting list is the same.
    train_prompts = list(
        itertools.islice(dataset.load_datapoints(split="train"), 1000)
    )

    trainer = trlx.train(
        reward_fn=reward_fn,
        prompts=train_prompts,
        config=config,
    )
    trainer.save_pretrained("dataset/trained_model")


if __name__ == "__main__":
    # Smoke-test the reward function on one sample per reward value before
    # starting training.
    assert (reward_fn(["Input: 1 Output: [-4,-5,-2] Function: div_n(reverse([-2, -5, -4]),1)"])) == [1]
    assert (reward_fn(["Input: 1 Output: [-4,-5,-2] Function: div_n(reverse([-2, -5, -a]),1)"])) == [-1]
    assert (reward_fn(["Input: 1 Output: [-4,-5,-2] Function: div_n(reverse([-2, -5, -3]),1)"])) == [-0.5]
    main()