{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Later cells import from the `src.TextSummarizer` package, so the working\n",
    "# directory has to be the folder that contains `src`.\n",
    "# Step up one level only while no `src` folder is visible: the guard keeps a\n",
    "# re-run of this cell from climbing one directory further each time.\n",
    "# Assumes the kernel starts one level below the project root - TODO confirm.\n",
    "if not os.path.isdir(\"src\"):\n",
    "    os.chdir(\"../\")\n",
    "\n",
    "# Show the resulting working directory as the cell output.\n",
    "%pwd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from dataclasses import dataclass\n",
    "from pathlib import Path\n",
    "\n",
    "\n",
    "@dataclass(frozen=True)\n",
    "class ModelTrainerConfig:\n",
    "    \"\"\"Read-only settings for the model-training stage.\n",
    "\n",
    "    Instances are built by `ConfigurationManager.get_model_trainer_config`.\n",
    "    `frozen=True` means fields cannot be reassigned after construction.\n",
    "    \"\"\"\n",
    "\n",
    "    # Taken from the `model_trainer` section of the config file.\n",
    "    # Directory created for this stage; passed as `output_dir` to TrainingArguments.\n",
    "    root_dir: str\n",
    "    # Location handed to `load_from_disk` when the dataset is loaded.\n",
    "    data_path: str\n",
    "    # Checkpoint identifier handed to `from_pretrained` for tokenizer and model.\n",
    "    model_ckpt: str\n",
    "\n",
    "    # Taken from the `TrainingArguments` section of the params file.\n",
    "    # NOTE(review): `ModelTrainer.train` currently passes literal values to\n",
    "    # TrainingArguments and does not read the fields below - confirm intent.\n",
    "    num_train_epochs: int\n",
    "    warmup_steps: int\n",
    "    per_device_train_batch_size: int\n",
    "    weight_decay: float\n",
    "    logging_steps: int\n",
    "    evaluation_strategy: str\n",
    "    eval_steps: int\n",
    "    save_steps: float\n",
    "    gradient_accumulation_steps: int"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from box import ConfigBox\n",
    "from pathlib import Path\n",
    "from src.TextSummarizer.constants import file_path\n",
    "from src.TextSummarizer.utils.general import read_yaml, create_directories\n",
    "\n",
    "\n",
    "class ConfigurationManager:\n",
    "    \"\"\"Reads the project config and params files and builds stage configs.\"\"\"\n",
    "\n",
    "    def __init__(self) -> None:\n",
    "        \"\"\"Load both files via `read_yaml` and create the artifacts root directory.\"\"\"\n",
    "        self.config: ConfigBox = read_yaml(Path(file_path.CONFIG_FILE_PATH))\n",
    "        self.params: ConfigBox = read_yaml(Path(file_path.PARAMS_FILE_PATH))\n",
    "\n",
    "        create_directories(path_to_directories=[self.config.artifacts_root])\n",
    "\n",
    "    def get_model_trainer_config(self) -> ModelTrainerConfig:\n",
    "        \"\"\"Build the configuration object for the model-training stage.\n",
    "\n",
    "        Paths and the checkpoint name come from the `model_trainer` section of\n",
    "        the config file; the remaining fields come from the `TrainingArguments`\n",
    "        section of the params file. The stage's `root_dir` is created as a\n",
    "        side effect.\n",
    "\n",
    "        Returns:\n",
    "            ModelTrainerConfig: the populated, frozen settings object.\n",
    "        \"\"\"\n",
    "        config = self.config.model_trainer\n",
    "        params = self.params.TrainingArguments\n",
    "\n",
    "        create_directories([config.root_dir])\n",
    "\n",
    "        return ModelTrainerConfig(\n",
    "            root_dir=config.root_dir,\n",
    "            data_path=config.data_path,\n",
    "            model_ckpt=config.model_ckpt,\n",
    "            num_train_epochs=params.num_train_epochs,\n",
    "            warmup_steps=params.warmup_steps,\n",
    "            per_device_train_batch_size=params.per_device_train_batch_size,\n",
    "            weight_decay=params.weight_decay,\n",
    "            logging_steps=params.logging_steps,\n",
    "            evaluation_strategy=params.evaluation_strategy,\n",
    "            # The step interval has its own key; `evaluation_strategy` holds\n",
    "            # the mode string and must not be reused here.\n",
    "            eval_steps=params.eval_steps,\n",
    "            save_steps=params.save_steps,\n",
    "            gradient_accumulation_steps=params.gradient_accumulation_steps,\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import TrainingArguments, Trainer\n",
    "from transformers import DataCollatorForSeq2Seq\n",
    "from transformers import AutoModelForSeq2SeqLM, AutoTokenizer\n",
    "from datasets import load_dataset, load_from_disk\n",
    "import torch\n",
    "\n",
    "\n",
    "class ModelTrainer:\n",
    "    \"\"\"Builds a seq2seq `Trainer` for the configured checkpoint and saves model and tokenizer.\"\"\"\n",
    "\n",
    "    def __init__(self, config: ModelTrainerConfig):\n",
    "        \"\"\"Keep the stage configuration for later use by `train`.\"\"\"\n",
    "        self.config = config\n",
    "\n",
    "    def train(self):\n",
    "        \"\"\"Load checkpoint and dataset, build the trainer, then save model and tokenizer.\n",
    "\n",
    "        The model is written to `multi-news-model` and the tokenizer to\n",
    "        `tokenizer`, both relative to the current working directory.\n",
    "\n",
    "        NOTE(review): the training arguments are literal values rather than\n",
    "        the fields of `self.config`, and the `trainer.train()` call is\n",
    "        commented out, so the saved weights are the ones loaded from the\n",
    "        checkpoint. Confirm whether both are intentional.\n",
    "        \"\"\"\n",
    "        checkpoint = self.config.model_ckpt\n",
    "\n",
    "        # Prefer the GPU when one is available.\n",
    "        target_device = \"cpu\"\n",
    "        if torch.cuda.is_available():\n",
    "            target_device = \"cuda\"\n",
    "\n",
    "        tokenizer = AutoTokenizer.from_pretrained(checkpoint)\n",
    "        model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)\n",
    "        model = model.to(target_device)\n",
    "        collator = DataCollatorForSeq2Seq(tokenizer, model=model)\n",
    "\n",
    "        # The \"train\" and \"validation\" entries of the loaded dataset are used below.\n",
    "        splits = load_from_disk(self.config.data_path)\n",
    "\n",
    "        # Only `output_dir` is read from the configuration here.\n",
    "        training_options = dict(\n",
    "            output_dir=self.config.root_dir,\n",
    "            num_train_epochs=1,\n",
    "            warmup_steps=500,\n",
    "            per_device_train_batch_size=1,\n",
    "            per_device_eval_batch_size=1,\n",
    "            weight_decay=0.01,\n",
    "            logging_steps=10,\n",
    "            evaluation_strategy='steps',\n",
    "            eval_steps=500,\n",
    "            save_steps=1e6,\n",
    "            gradient_accumulation_steps=16,\n",
    "        )\n",
    "        trainer_args = TrainingArguments(**training_options)\n",
    "\n",
    "        trainer = Trainer(\n",
    "            model=model,\n",
    "            args=trainer_args,\n",
    "            tokenizer=tokenizer,\n",
    "            data_collator=collator,\n",
    "            train_dataset=splits[\"train\"],\n",
    "            eval_dataset=splits[\"validation\"],\n",
    "        )\n",
    "\n",
    "        # Training is currently disabled; re-enable the next line to fine-tune.\n",
    "        # trainer.train()\n",
    "\n",
    "        # Persist the model, then the tokenizer.\n",
    "        model.save_pretrained(\"multi-news-model\")\n",
    "        tokenizer.save_pretrained(\"tokenizer\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the model-training stage end to end. Errors propagate unchanged, so no\n",
    "# try/except wrapper is needed around these calls.\n",
    "config_manager = ConfigurationManager()\n",
    "trainer_config = config_manager.get_model_trainer_config()\n",
    "\n",
    "# Separate name for the trainer so the config object is not shadowed.\n",
    "model_trainer = ModelTrainer(config=trainer_config)\n",
    "model_trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}