SeemG committed on
Commit
210c3a8
·
verified ·
1 Parent(s): a5ac7b2

Upload 4 files

Files changed (4)
  1. final_gpt2.ipynb +1 -0
  2. gpt_124M_30thJune2024.pth +3 -0
  3. model.py +483 -0
  4. training.ipynb +0 -0
final_gpt2.ipynb ADDED
@@ -0,0 +1 @@
+ (The file is a single line of notebook JSON; the cells are rendered below for readability, with pip progress bars and ANSI escape noise stripped.)

Cell 1: mount Google Drive.

    from google.colab import drive
    drive.mount('/content/drive')

    # Output: Mounted at /content/drive

Cell 2: set the working directory.

    import os
    script_dir = os.path.dirname("/content/drive/MyDrive/S21/classnotes.ipynb")
    # Change the cwd to the script's directory
    os.chdir(script_dir)
    print(os.getcwd())  # /content/drive/MyDrive/S21

Cell 3: install dependencies.

    !pip install -q gradio
    !pip install -q tiktoken

Cell 4: imports.

    from transformers import GPT2LMHeadModel
    import tiktoken
    import torch
    import gradio as gr
    import model
    from model import run_train, gen_text

Cell 5: checkpoint path.

    # Specify a path
    PATH = "/content/drive/MyDrive/S21/gpt_124M_30thJune2024.pth"

Cell 6: model configuration (mirrors GPTConfig in model.py).

    class GPTConfig:
        block_size: int = 1024   # max sequence length
        vocab_size: int = 50304  # 50,000 BPE merges + 256 byte tokens + 1 <|endoftext|> token, padded up for efficiency
        n_layer: int = 12        # number of layers
        n_head: int = 12         # number of heads
        n_embd: int = 768        # embedding dimension

Cell 7: build the model and load saved weights.

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model2 = model.GPT(GPTConfig())
    model2 = model2.to(device)
    # model2.load_state_dict(torch.load(PATH), map_location=torch.device('cpu'))  # buggy: map_location belongs inside torch.load
    model2.load_state_dict(torch.load('/content/drive/MyDrive/S21/gpt_124M_1.pth', map_location=torch.device(device)))

    # Output: <All keys matched successfully>

Cell 8: generate five 50-token continuations.

    start_tokens = "The adventurer delved into the dark..."  # starting prompt
    max_length = 50             # maximum length of the generated text
    num_return_sequences = 5    # number of text sequences to generate

    decoded = gen_text(model2, start_tokens, max_length, num_return_sequences)

    # Output: five continuations in the style of the Shakespeare training data, metrically plausible but
    # largely incoherent, e.g. "> The adventurer delved into the dark..., tell / For since that such virtue, /
    # For great purpose to the king's another ill-t passion express the best and we shall / Which never for me. /
    # GLOUCESTER:"

Cell 9: generate ten 100-token continuations.

    start_tokens = "The prophecy was clear: only the chosen one, with the perfect social media profile, could defeat the forces of online negativity..."
    max_length = 100
    num_return_sequences = 10

    gen_text(model2, start_tokens, max_length, num_return_sequences)

    # Output: ten more continuations of the same Shakespeare-flavoured quality
    # (e.g. "... me; and deign of a wont / For now remain / For from men and that pluck supper buried...").

Cell 10: first Gradio demo, a dropdown of canned prompts.

    import gradio as gr
    from transformers import GPT2Tokenizer, AutoModelForCausalLM

    start_tokens = "The prophecy was clear: only the chosen one, with the perfect social media profile, could defeat the forces of online negativity..."
    max_length = 100            # maximum length of the generated text
    num_return_sequences = 10   # number of text sequences to generate

    # Define generation function
    def generate_text(prompt):
        start_tokens = prompt
        output = gen_text(model2, start_tokens, max_length, num_return_sequences)
        return output

    # Humorous prompt options
    prompts = [
        "Karen, armed with a coupon...",
        "Gary the goldfish, tired...",
        "The local news team received...",
        "In a shocking turn of events...",
        "The kingdom's annual jousting...",
        "The sentient toaster, tired...",
        "Feeling underappreciated...",
        "The fortune cookie factory..."
    ]

    # Gradio interface with dropdown for prompt selection
    interface = gr.Interface(
        fn=generate_text,
        inputs=gr.Dropdown(choices=prompts, label="Humorous Prompt"),
        outputs="text",
        title="Humorous Text Generator with GPT-2",
        description="Get a chuckle with AI-generated humorous stories based on your chosen prompt."
    )

    # Launch the Gradio app
    interface.launch()

    # Output: a public gradio.live share link (expires in 72 hours) with the app embedded in an iframe.

Cell 11: second Gradio demo, adding sliders for length and number of outputs.

    # Define generation function
    def generate_text(prompt, max_length=50, num_return_sequences=10):
        """Generates humorous text using the GPT-2 model based on the provided prompt and user-specified parameters.

        Args:
            prompt (str): The starting text for the generation.
            max_length (int, optional): The maximum length of the generated text. Defaults to 50.
            num_return_sequences (int, optional): The number of different text sequences to generate. Defaults to 10.

        Returns:
            list: A list of generated humorous text sequences.
        """
        start_tokens = prompt
        generated_texts = gen_text(model2, start_tokens, max_length, num_return_sequences)
        return generated_texts

    # Humorous prompt options
    prompts = [
        "The automatic doors at the grocery store, tired of people holding them open for conversations, developed a mischievous sense of humor.",
        "The self-driving car, fed up with rush hour traffic, decided to take a scenic detour through the countryside.",
        "The fridge goes on a hunger strike",
        "A colony of ants, inspired by a motivational poster, embarked on a quest to climb the tallest tree in the garden.",
        "A particularly chatty parrot accidentally spilled the villain's evil plan during a casual conversation with the local mailman.",
        "A squirrel declares war on the birdfeeder...",
        "The refrigerator, overflowing with forgotten groceries, staged a silent protest, refusing to cool anything until some order was restored.",
        "A fortune cookie predicts world domination"
    ]

    # Gradio interface with user inputs and dropdown for prompt selection
    interface = gr.Interface(
        fn=generate_text,
        inputs=[
            gr.Dropdown(choices=prompts, label="Pre-defined Prompt"),
            gr.Slider(minimum=10, maximum=200, label="Max Text Length", value=100, step=1),
            gr.Slider(minimum=1, maximum=20, label="Number of Outputs", value=10, step=1)
        ],
        outputs="text",
        title="Humorous Text Generator with GPT-2",
        description="Get a chuckle with AI-generated funny stories! Provide a prompt (or choose one), adjust the desired text length and number of outputs, and let the AI do the rest!",
    )

    # Launch the Gradio app
    interface.launch()

    # Output: another public gradio.live share link with the app embedded.

Cell 12: direct call with one of the prompts.

    generate_text("The self-driving car, fed up with rush hour traffic, decided to take a scenic detour through the countryside.", max_length=50, num_return_sequences=10)

    # Output: ten printed continuations that again drift into Shakespearean dialogue
    # ("And not the Lord Hastings...", "But Romeo...").

Cell 13: an incomplete cell whose source is just "start_tokens ="; its displayed output is the previous prompt string.
gpt_124M_30thJune2024.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:951d9d1e0f2869d73873484c2f6ad89866831a9e8daac29550b99b9dcaf4eae9
+ size 548295018
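For scale, this size is consistent with a float32 state dict for the 124M-parameter config in model.py: roughly 498 MB of weights plus about 50 MB for the twelve persistent 1024×1024 causal-mask buffers the model registers, which lands near the 548 MB recorded here (my arithmetic; the commit itself does not state this).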
model.py ADDED
@@ -0,0 +1,483 @@
+ # GPT-3 Paper
+ # add cosine decay
+ import os
+ import math
+ import time
+ import inspect
+ from dataclasses import dataclass
+ import torch
+ import torch.nn as nn
+ import tiktoken
+ from torch.nn import functional as F
+
+
+ class CausalSelfAttention(nn.Module):
+
+     def __init__(self, config):
+         super().__init__()
+
+         # assertion to ensure the embedding dimension is divisible by the number of heads (important for reshaping later)
+         assert config.n_embd % config.n_head == 0
+
+         # key, query, value projections for all heads, but in a batch; each vector has the same dimension (C) as the input embedding
+         self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
+
+         # output projection
+         self.c_proj = nn.Linear(config.n_embd, config.n_embd)
+         self.c_proj.NANOGPT_SCALE_INIT = 1
+
+         # regularization
+         self.n_head = config.n_head
+         self.n_embd = config.n_embd
+         self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size)).view(1, 1, config.block_size, config.block_size))
+
+     def forward(self, x):
+         # x is the tokenised version of input.txt
+         B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
+         # calculate query, key, values for all heads in batch and move head forward to be the batch dim
+         # nh is "number of heads", hs is "head size", and C (number of channels) = nh * hs
+         # e.g. in GPT-2 (124M), n_head=12, hs=64, so nh*hs=C=768 channels in the Transformer
+         qkv = self.c_attn(x)
+         q, k, v = qkv.split(self.n_embd, dim=2)
+         k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
+         q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
+         v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
+
+         # equivalent manual attention, kept for reference:
+         # att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
+         # att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
+         # att = F.softmax(att, dim=-1)
+         # y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
+
+         # this call fuses the dot product, scaling, causal masking, and softmax into a single kernel
+         y = F.scaled_dot_product_attention(q, k, v, is_causal=True) # Flash attention
+
+         y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
+         # output projection
+         y = self.c_proj(y)
+         return y
+
+
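Since the manual attention path survives only as comments, here is a self-contained check (my sketch, not part of the commit) that the fused call computes the same thing:

    import math
    import torch
    import torch.nn.functional as F

    B, nh, T, hs = 2, 12, 16, 64
    q, k, v = (torch.randn(B, nh, T, hs) for _ in range(3))

    # manual path: scaled dot product, causal mask, softmax
    att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
    mask = torch.tril(torch.ones(T, T)).view(1, 1, T, T)
    att = att.masked_fill(mask == 0, float('-inf'))
    y_manual = F.softmax(att, dim=-1) @ v

    y_fused = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    print(torch.allclose(y_manual, y_fused, atol=1e-5))  # True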
+ class MLP(nn.Module):
+     # MLP (multi-layer perceptron)
+     # This class implements the feed-forward sub-module used inside each transformer block
+     # for non-linear transformations.
+
+     def __init__(self, config):
+         # expand, then squeeze
+         super().__init__()
+         # c_fc: projects the input (x) to a dimension four times larger than the embedding dimension (n_embd)
+         self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd)
+
+         # GELU (Gaussian Error Linear Unit) activation function for non-linearity;
+         # the approximate tanh version is used here
+         self.gelu = nn.GELU(approximate='tanh')
+
+         # projects the output back to the original embedding dimension (n_embd)
+         self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd)
+         self.c_proj.NANOGPT_SCALE_INIT = 1
+
+     def forward(self, x):
+         # linear layer (c_fc), then the GELU activation, then the final projection (c_proj)
+         x = self.c_fc(x)
+         x = self.gelu(x)
+         x = self.c_proj(x)
+         return x
+
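The approximate='tanh' choice matches the original GPT-2; the difference from the exact GELU is tiny, as a quick comparison shows (my sketch):

    import torch
    import torch.nn as nn

    x = torch.linspace(-3.0, 3.0, 7)
    print(nn.GELU()(x))                    # exact, erf-based GELU
    print(nn.GELU(approximate='tanh')(x))  # agrees to roughly 3-4 decimal places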
+ class Block(nn.Module):
+     # Combines the CausalSelfAttention layer and the MLP layer into a single transformer block.
+     # Each sub-layer is preceded by layer normalization and wrapped in a residual connection
+     # (the "pre-norm" arrangement used by GPT-2).
+
+     def __init__(self, config):
+         super().__init__()
+         # ln_1: layer normalization applied before the causal self-attention
+         # attn: an instance of CausalSelfAttention (above)
+         # ln_2: layer normalization applied before the MLP
+         # mlp: an instance of MLP (above)
+         self.ln_1 = nn.LayerNorm(config.n_embd)
+         self.attn = CausalSelfAttention(config)
+         self.ln_2 = nn.LayerNorm(config.n_embd)
+         self.mlp = MLP(config)
+
+     def forward(self, x):
+         # residual connection around the attention layer (attn), preceded by ln_1,
+         # then a second residual connection around the MLP, preceded by ln_2
+         x = x + self.attn(self.ln_1(x))
+         x = x + self.mlp(self.ln_2(x))
+         return x
+
+
+ @dataclass
+ class GPTConfig:
+     block_size: int = 1024 # max sequence length
+     vocab_size: int = 50304 # 50,000 BPE merges + 256 byte tokens + 1 <|endoftext|> token = 50,257, padded up to a multiple of 64 for efficiency
+     n_layer: int = 12 # number of layers
+     n_head: int = 12 # number of heads
+     n_embd: int = 768 # embedding dimension
+
+
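The config advertises the 124M model; a back-of-the-envelope count (my arithmetic, with the LM head tied to wte and therefore not counted twice) confirms it:

    cfg = GPTConfig()
    d, L, V, T = cfg.n_embd, cfg.n_layer, cfg.vocab_size, cfg.block_size
    per_block = (
        2 * 2 * d                # ln_1 and ln_2 (weight + bias each)
        + (d * 3 * d + 3 * d)    # attn.c_attn
        + (d * d + d)            # attn.c_proj
        + (d * 4 * d + 4 * d)    # mlp.c_fc
        + (4 * d * d + d)        # mlp.c_proj
    )
    total = V * d + T * d + L * per_block + 2 * d  # wte + wpe + blocks + ln_f
    print(f"{total / 1e6:.1f}M parameters")        # ~124.5M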
+ class GPT(nn.Module):
+
+     def __init__(self, config):
+         super().__init__()
+         self.config = config
+
+         # Creates a transformer module dictionary containing the key components:
+         # wte: word token embedding layer (nn.Embedding); maps each token index to its embedding vector
+         # wpe: positional embedding layer (nn.Embedding); adds positional information to the token embeddings
+         # h: a module list of n_layer Block instances (above), the core processing units of the transformer
+         # ln_f: final layer normalization (nn.LayerNorm) applied to the output of the transformer blocks
+         self.transformer = nn.ModuleDict(dict(
+             wte = nn.Embedding(config.vocab_size, config.n_embd),
+             wpe = nn.Embedding(config.block_size, config.n_embd),
+             h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
+             ln_f = nn.LayerNorm(config.n_embd),
+         ))
+
+         # The language modeling head (lm_head) is a linear layer that projects the final hidden state
+         # from the transformer to the vocabulary size, predicting the next token in the sequence.
+         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+
+         # Weight sharing between the word token embedding layer (wte) and the language modeling head (lm_head).
+         # This reduces the number of parameters and encourages the model to learn one representation for
+         # tokens that serves both embedding and prediction.
+         self.transformer.wte.weight = self.lm_head.weight
+
+         # weight initialization via a custom function (_init_weights)
+         self.apply(self._init_weights)
+
+     def _init_weights(self, module):
+         if isinstance(module, nn.Linear):
+             std = 0.02
+             if hasattr(module, 'NANOGPT_SCALE_INIT'):
+                 # scale residual projections down by 1/sqrt(2 * n_layer) to keep activations stable
+                 std *= (2 * self.config.n_layer) ** -0.5
+             torch.nn.init.normal_(module.weight, mean=0.0, std=std)
+             if module.bias is not None:
+                 torch.nn.init.zeros_(module.bias)
+         elif isinstance(module, nn.Embedding):
+             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+
+     def forward(self, idx, targets=None):
+         # idx is of shape (B, T)
+         B, T = idx.size()
+         assert T <= self.config.block_size, f"Cannot forward sequence of length {T}, block size is only {self.config.block_size}"
+         # forward the token and position embeddings
+         pos = torch.arange(0, T, dtype=torch.long, device=idx.device) # shape (T)
+         pos_emb = self.transformer.wpe(pos) # position embeddings of shape (T, n_embd)
+         tok_emb = self.transformer.wte(idx) # token embeddings of shape (B, T, n_embd)
+         x = tok_emb + pos_emb
+         # forward the blocks of the transformer
+         for block in self.transformer.h:
+             x = block(x)
+         # forward the final layernorm and the classifier
+         x = self.transformer.ln_f(x)
+         logits = self.lm_head(x) # (B, T, vocab_size)
+         loss = None
+         if targets is not None:
+             loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
+         return logits, loss
+
+     @classmethod
+     def from_pretrained(cls, model_type):
+         """Loads pretrained GPT-2 model weights from huggingface"""
+         assert model_type in {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'}
+         from transformers import GPT2LMHeadModel
+         print("loading weights from pretrained gpt: %s" % model_type)
+
+         # n_layer, n_head and n_embd are determined from model_type
+         config_args = {
+             'gpt2':        dict(n_layer=12, n_head=12, n_embd=768),  # 124M params
+             'gpt2-medium': dict(n_layer=24, n_head=16, n_embd=1024), # 350M params
+             'gpt2-large':  dict(n_layer=36, n_head=20, n_embd=1280), # 774M params
+             'gpt2-xl':     dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params
+         }[model_type]
+         config_args['vocab_size'] = 50257 # always 50257 for GPT model checkpoints
+         config_args['block_size'] = 1024 # always 1024 for GPT model checkpoints
+         # create a from-scratch initialized minGPT model
+         config = GPTConfig(**config_args)
+         model = GPT(config)
+         sd = model.state_dict()
+         sd_keys = sd.keys()
+         sd_keys = [k for k in sd_keys if not k.endswith('.attn.bias')] # discard this mask / buffer, not a param
+
+         # init a huggingface/transformers model
+         model_hf = GPT2LMHeadModel.from_pretrained(model_type)
+         sd_hf = model_hf.state_dict()
+
+         # copy while ensuring all of the parameters are aligned and match in names and shapes
+         sd_keys_hf = sd_hf.keys()
+         sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.masked_bias')] # ignore these, just a buffer
+         sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.bias')] # same, just the mask (buffer)
+         transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight']
+         # the openai checkpoints use a "Conv1D" module, but we only want to use a vanilla Linear;
+         # this means we have to transpose these weights when we import them
+         assert len(sd_keys_hf) == len(sd_keys), f"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}"
+         for k in sd_keys_hf:
+             if any(k.endswith(w) for w in transposed):
+                 # special treatment for the Conv1D weights we need to transpose
+                 assert sd_hf[k].shape[::-1] == sd[k].shape
+                 with torch.no_grad():
+                     sd[k].copy_(sd_hf[k].t())
+             else:
+                 # vanilla copy over the other parameters
+                 assert sd_hf[k].shape == sd[k].shape
+                 with torch.no_grad():
+                     sd[k].copy_(sd_hf[k])
+
+         return model
+
+     def configure_optimizers(self, weight_decay, learning_rate, device_type):
+         # start with all of the candidate parameters (that require grad)
+         param_dict = {pn: p for pn, p in self.named_parameters()}
+         param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
+         # create optim groups: any parameter that is 2D will be weight-decayed, otherwise not;
+         # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't
+         decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
+         nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
+         optim_groups = [
+             {'params': decay_params, 'weight_decay': weight_decay},
+             {'params': nodecay_params, 'weight_decay': 0.0}
+         ]
+         num_decay_params = sum(p.numel() for p in decay_params)
+         num_nodecay_params = sum(p.numel() for p in nodecay_params)
+         print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
+         print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
+         # create the AdamW optimizer and use the fused version if it is available
+         fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
+         use_fused = fused_available and device_type == "cuda"
+         print(f"using fused AdamW: {use_fused}")
+         optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=(0.9, 0.95), eps=1e-8, fused=use_fused)
+         return optimizer
+
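A quick sanity check on the weight tying and the scaled residual init (my sketch, not part of the commit):

    m = GPT(GPTConfig())
    # tied weights: the token embedding and the LM head share one tensor
    assert m.lm_head.weight.data_ptr() == m.transformer.wte.weight.data_ptr()
    # residual projections get the smaller std: 0.02 / sqrt(2 * n_layer) ≈ 0.0041
    print(m.transformer.h[0].attn.c_proj.weight.std().item())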
+ # model = GPT.from_pretrained('gpt2')
+
+ device = 'cpu'
+ if torch.cuda.is_available():
+     device = 'cuda'
+ elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
+     device = "mps"
+ print(f"using device: {device}")
+
+ # SEED
+ torch.manual_seed(1337)
+ if torch.cuda.is_available():
+     torch.cuda.manual_seed(1337)
+
+ # STOP
+ # num_return_sequences = 5
+ # max_length = 30
+
+
+ class DataLoaderLite:
+     def __init__(self, B, T):
+         self.B = B
+         self.T = T
+
+         # at init load tokens from disk and store them in memory
+         with open('input.txt', 'r') as f:
+             text = f.read()
+         enc = tiktoken.get_encoding('gpt2')
+         tokens = enc.encode(text)
+         self.tokens = torch.tensor(tokens)
+         print(f'loaded {len(self.tokens)} tokens')
+         print(f'1 epoch = {len(self.tokens) // (B * T)} batches')
+
+         # state
+         self.current_position = 0
+
+     def next_batch(self):
+         B, T = self.B, self.T
+         buf = self.tokens[self.current_position : self.current_position + B * T + 1]
+         x = (buf[:-1]).view(B, T) # inputs
+         y = (buf[1:]).view(B, T) # targets
+         # advance the position in the tensor
+         self.current_position += B * T
+         # if loading the next batch would be out of bounds, reset
+         if self.current_position + (B * T + 1) > len(self.tokens):
+             self.current_position = 0
+         return x, y
+
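To make the input/target relationship in next_batch concrete, here is the same slicing on a toy tensor (names and numbers are mine):

    import torch

    tokens = torch.arange(10)      # stand-in for the encoded text
    B, T = 2, 4
    buf = tokens[: B * T + 1]      # one extra token supplies the final target
    x = buf[:-1].view(B, T)        # [[0, 1, 2, 3], [4, 5, 6, 7]]
    y = buf[1:].view(B, T)         # [[1, 2, 3, 4], [5, 6, 7, 8]]
    # y[i, j] is the token that should follow x[i, :j+1]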
+ # CHANGES IN CURRENT CODE
+ torch.set_float32_matmul_precision('high') # allow TF32 matmuls on supporting GPUs
+ model = GPT(GPTConfig())
+ model.to(device)
+ # model = torch.compile(model)
+
+ # CODE UPDATE HERE
+ max_lr = 6e-4
+ min_lr = max_lr * 0.1
+ # warmup_steps = 100
+ # max_steps = 50
+
+ def get_lr(it, warmup_steps, max_steps):
+     # linear warmup followed by the cosine decay noted at the top of the file
+     if it < warmup_steps:
+         return max_lr * (it + 1) / warmup_steps
+     if it > max_steps:
+         return min_lr
+     decay_ratio = (it - warmup_steps) / (max_steps - warmup_steps)
+     assert 0 <= decay_ratio <= 1
+     coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
+     return min_lr + coeff * (max_lr - min_lr)
+
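A quick probe of the schedule's shape (my own numbers: warmup_steps=10 and max_steps=50, so the cosine segment is actually exercised):

    for it in [0, 9, 10, 30, 50, 60]:
        print(it, f"{get_lr(it, warmup_steps=10, max_steps=50):.2e}")
    # 0   6.00e-05  linear warmup starts at max_lr / warmup_steps
    # 9   6.00e-04  end of warmup: max_lr
    # 10  6.00e-04  cosine decay begins
    # 30  3.30e-04  halfway point of the decay
    # 50  6.00e-05  decayed to min_lr
    # 60  6.00e-05  held at min_lr afterwards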
+
+ # NEW CODE
+ train_loader = DataLoaderLite(B=8, T=512)
+
+ # train_loader = DataLoaderLite(B = B, T = T)
+ x, y = train_loader.next_batch()
+ x.shape, y.shape
+
+ def run_train(max_steps=50, warmup_steps=100, PATH="/content/drive/MyDrive/S21/gpt_124M.pth"):
+     # optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, betas=(0.9, 0.95), eps=1e-8)
+     optimizer = model.configure_optimizers(weight_decay=0.1, learning_rate=6e-4, device_type=device)
+     for step in range(max_steps):
+         t0 = time.time()
+         x, y = train_loader.next_batch()
+         x, y = x.to(device), y.to(device)
+         optimizer.zero_grad()
+         # NEW CODE ADDED HERE: mixed-precision forward pass
+         with torch.autocast(device_type=device, dtype=torch.bfloat16):
+             logits, loss = model(x, y)
+         loss.backward()
+         norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+         # NEW CODE: set the learning rate for this step from the schedule
+         lr = get_lr(step, warmup_steps=warmup_steps, max_steps=max_steps)
+         for param_group in optimizer.param_groups:
+             param_group['lr'] = lr
+
+         optimizer.step()
+         if device == 'cuda':
+             torch.cuda.synchronize() # wait for the GPU to finish before timing
+         t1 = time.time()
+         dt = (t1 - t0) * 1000
+         tokens_per_sec = (train_loader.B * train_loader.T) / (t1 - t0)
+         print(f'step {step} | loss: {loss.item()} | dt: {dt:.2f}ms | tok/sec: {tokens_per_sec:.2f} | norm: {norm:.2f}')
+     print(loss)
+     torch.save(model.state_dict(), PATH)
+     return model
+
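Training is then a single call. Note that with the defaults below, warmup_steps (100) exceeds max_steps (50), so the whole run stays on the linear warmup ramp and the cosine segment is never reached:

    trained = run_train(max_steps=50, warmup_steps=100,
                        PATH="/content/drive/MyDrive/S21/gpt_124M.pth")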
+ def load_fromsaved(PATH="/content/drive/MyDrive/S21/gpt_124M.pth"):
+     # create a new GPT model instance
+     model = GPT(GPTConfig())
+     model.to(device)
+
+     # load the saved weights into the model
+     # (map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine,
+     # as the accompanying notebook does)
+     model.load_state_dict(torch.load(PATH, map_location=torch.device(device)))
+
+     # print confirmation message
+     print("Loaded model weights from:", PATH)
+
+     return model
+
+
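A later session can then reload and sample in two lines (a usage sketch; the notebook loads its checkpoint directly with load_state_dict instead):

    model2 = load_fromsaved(PATH="/content/drive/MyDrive/S21/gpt_124M.pth")
    gen_text(model2, "The adventurer delved into the dark...", max_length=50, num_return_sequences=5)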
+ def gen_text(model, start_tokens, max_length=100, num_return_sequences=10):
+     """
+     Generates text using the loaded GPT model.
+
+     Args:
+         model: The GPT model to use for generation.
+         start_tokens: The prompt string to start generation from.
+         max_length: The maximum length of the generated text, in tokens.
+         num_return_sequences: The number of text sequences to generate.
+
+     Returns:
+         None (the generated sequences are printed).
+     """
+     enc = tiktoken.get_encoding('gpt2')
+     tokens = enc.encode(start_tokens)
+     tokens = torch.tensor(tokens, dtype=torch.long) # (prompt_len,)
+     tokens = tokens.unsqueeze(0).repeat(num_return_sequences, 1) # (num_return_sequences, prompt_len)
+     x = tokens.to(device)
+
+     # set random seeds for consistent generation across runs
+     torch.manual_seed(42)
+     torch.cuda.manual_seed(42)
+
+     while x.size(1) < max_length:
+         # forward the model to get the logits
+         with torch.no_grad():
+             logits = model(x)[0] # (B, T, vocab_size)
+         # take the logits at the last position
+         logits = logits[:, -1, :] # (B, vocab_size)
+         # get the probabilities
+         probs = F.softmax(logits, dim=-1)
+         # do top-k sampling of 50 (huggingface pipeline default)
+         # topk_probs and topk_indices are both (B, 50)
+         topk_probs, topk_indices = torch.topk(probs, 50, dim=-1)
+         # select a token from the top-k probabilities
+         # note: multinomial does not demand the input to sum to 1
+         ix = torch.multinomial(topk_probs, 1) # (B, 1)
+         # gather the corresponding indices
+         xcol = torch.gather(topk_indices, -1, ix) # (B, 1)
+         # append to the sequence
+         x = torch.cat((x, xcol), dim=1)
+
+     # print the generated text
+     for i in range(num_return_sequences):
+         tokens = x[i, :max_length].tolist()
+         decoded = enc.decode(tokens)
+         print(">", decoded)
+     # decoded_texts.append(decoded)
+     # # Join all the decoded texts into a single string and print it
+     # final_decoded_text = "".join(decoded_texts)
+     # print(final_decoded_text)
+     # return final_decoded_text
+
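The top-k step is easy to see in isolation; torch.multinomial accepts unnormalized weights, so the top-k probabilities need no renormalization (a toy probe, my addition):

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    logits = torch.tensor([[4.0, 3.0, 2.0, 1.0, 0.0, -1.0]])
    probs = F.softmax(logits, dim=-1)
    topk_probs, topk_indices = torch.topk(probs, 3, dim=-1)  # keep the 3 likeliest tokens
    ix = torch.multinomial(topk_probs, 1)                    # sample among them
    print(torch.gather(topk_indices, -1, ix))                # a token id from {0, 1, 2}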
+
+
+ # An earlier variant of gen_text that took an already-encoded token tensor:
+ # def gen_text(model, x=x, max_length=100, num_return_sequences=10):
+ #     torch.manual_seed(42)
+ #     torch.cuda.manual_seed(42)
+ #     while x.size(1) < max_length:
+ #         # forward the model to get the logits
+ #         with torch.no_grad():
+ #             logits = model(x)[0] # (B, T, vocab_size)
+ #         # take the logits at the last position
+ #         logits = logits[:, -1, :] # (B, vocab_size)
+ #         # get the probabilities
+ #         probs = F.softmax(logits, dim=-1)
+ #         # top-k sampling of 50 (huggingface pipeline default)
+ #         topk_probs, topk_indices = torch.topk(probs, 50, dim=-1)
+ #         # select a token from the top-k probabilities
+ #         ix = torch.multinomial(topk_probs, 1) # (B, 1)
+ #         # gather the corresponding indices
+ #         xcol = torch.gather(topk_indices, -1, ix) # (B, 1)
+ #         # append to the sequence
+ #         x = torch.cat((x, xcol), dim=1)
+ #
+ #     # print the generated text
+ #     for i in range(num_return_sequences):
+ #         tokens = x[i, :max_length].tolist()
+ #         decoded = enc.decode(tokens)
+ #         print(">", decoded)
training.ipynb ADDED
The diff for this file is too large to render. See raw diff