Upload 4 files
- final_gpt2.ipynb +1 -0
- gpt_124M_30thJune2024.pth +3 -0
- model.py +483 -0
- training.ipynb +0 -0
final_gpt2.ipynb
ADDED
@@ -0,0 +1 @@
# --- final_gpt2.ipynb, cell by cell (outputs shown condensed as comments) ---

# In[1]:
from google.colab import drive
drive.mount('/content/drive')
# Output: Mounted at /content/drive

# In[2]:
import os
script_dir = os.path.dirname("/content/drive/MyDrive/S21/classnotes.ipynb")
# Change the cwd to the script's directory
os.chdir(script_dir)

# Now the cwd is set to the directory containing the script (and potentially the file)
print(os.getcwd())  # This will print the current working directory
# Output: /content/drive/MyDrive/S21

# In[3]:
!pip install -q gradio
!pip install -q tiktoken
# Output: pip download/progress logs (omitted)

# In[6]:
from transformers import GPT2LMHeadModel
import tiktoken
import torch
import gradio as gr
import model
from model import run_train, gen_text

# In[7]:
# Specify a path
PATH = "/content/drive/MyDrive/S21/gpt_124M_30thJune2024.pth"

# In[8]:
class GPTConfig:
    block_size: int = 1024   # max sequence length
    vocab_size: int = 50304  # number of tokens: 50,000 BPE merges + 256 byte tokens + 1 <|endoftext|> token
    n_layer: int = 12        # number of layers
    n_head: int = 12         # number of heads
    n_embd: int = 768        # embedding dimension

# In[9]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model2 = model.GPT(GPTConfig())
model2 = model2.to(device)
# model2.load_state_dict(torch.load(PATH), map_location=torch.device('cpu'))
model2.load_state_dict(torch.load('/content/drive/MyDrive/S21/gpt_124M_1.pth', map_location=torch.device(device)))
# Output: <All keys matched successfully>
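# Note: this GPTConfig shadows model.GPTConfig and lacks the @dataclass decorator used
# in model.py; model.GPT only reads plain attributes from its config object, so
# GPTConfig() still works here. Note also that the state dict is loaded from
# gpt_124M_1.pth rather than the gpt_124M_30thJune2024.pth file added in this commit.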
# In[8]:
start_tokens = "The adventurer delved into the dark..."  # Starting prompt
max_length = 50               # Maximum length of the generated text
num_return_sequences = 5      # Number of text sequences to generate

decoded = gen_text(model2, start_tokens, max_length, num_return_sequences)
# Output: five Shakespeare-flavoured continuations, e.g.
# > The adventurer delved into the dark..., speak there's faithful madward. ...
# (remaining sequences omitted)

# In[16]:
start_tokens = "The prophecy was clear: only the chosen one, with the perfect social media profile, could defeat the forces of online negativity..."  # Starting prompt
max_length = 100              # Maximum length of the generated text
num_return_sequences = 10     # Number of text sequences to generate

gen_text(model2, start_tokens, max_length, num_return_sequences)
# Output: ten continuations of the prompt in the same Shakespearean register (omitted)
# In[13]:
import gradio as gr
from transformers import GPT2Tokenizer, AutoModelForCausalLM

start_tokens = "The prophecy was clear: only the chosen one, with the perfect social media profile, could defeat the forces of online negativity..."  # Starting prompt
max_length = 100              # Maximum length of the generated text
num_return_sequences = 10     # Number of text sequences to generate

# gen_text(model2, start_tokens, max_length, num_return_sequences)

# Define generation function
def generate_text(prompt):
    start_tokens = prompt
    output = gen_text(model2, start_tokens, max_length, num_return_sequences)
    # return tokenizer.decode(output[0], skip_special_tokens=True)
    return output

# Humorous prompt options
prompts = [
    "Karen, armed with a coupon...",
    "Gary the goldfish, tired...",
    "The local news team received...",
    "In a shocking turn of events...",
    "The kingdom's annual jousting...",
    "The sentient toaster, tired...",
    "Feeling underappreciated...",
    "The fortune cookie factory..."
]

# Gradio interface with dropdown for prompt selection
interface = gr.Interface(
    fn=generate_text,
    inputs=gr.Dropdown(choices=prompts, label="Humorous Prompt"),
    outputs="text",
    title="Humorous Text Generator with GPT-2",
    description="Get a chuckle with AI-generated humorous stories based on your chosen prompt."
)

# Launch the Gradio app
interface.launch()
# Output: Running on public URL: https://e6c7439a4e347d91e9.gradio.live
# (the app is rendered inline in an iframe; the share link expires in 72 hours)

# In[28]:
# Define generation function
def generate_text(prompt, max_length=50, num_return_sequences=10):
    """Generates humorous text using the GPT-2 model based on the provided prompt and user-specified parameters.

    Args:
        prompt (str): The starting text for the generation.
        max_length (int, optional): The maximum length of the generated text. Defaults to 50.
        num_return_sequences (int, optional): The number of different text sequences to generate. Defaults to 10.

    Returns:
        str: The generated humorous text sequences.
    """
    start_tokens = prompt
    # generated_texts = model.generate(
    #     inputs, max_length=max_length, num_return_sequences=num_return_sequences
    # )
    generated_texts = gen_text(model2, start_tokens, max_length, num_return_sequences)
    return generated_texts

# Humorous prompt options
prompts = [
    "The automatic doors at the grocery store, tired of people holding them open for conversations, developed a mischievous sense of humor.",
    "The self-driving car, fed up with rush hour traffic, decided to take a scenic detour through the countryside.",
    "The fridge goes on a hunger strike",
    "A colony of ants, inspired by a motivational poster, embarked on a quest to climb the tallest tree in the garden.",
    "A particularly chatty parrot accidentally spilled the villain's evil plan during a casual conversation with the local mailman.",
    "A squirrel declares war on the birdfeeder...",
    "The refrigerator, overflowing with forgotten groceries, staged a silent protest, refusing to cool anything until some order was restored.",
    "A fortune cookie predicts world domination"
]

# Gradio interface with user inputs and dropdown for prompt selection
interface = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Dropdown(choices=prompts, label="Pre-defined Prompt"),
        gr.Slider(minimum=10, maximum=200, label="Max Text Length", value=100, step=1),
        gr.Slider(minimum=1, maximum=20, label="Number of Outputs", value=10, step=1)
    ],
    outputs="text",
    title="Humorous Text Generator with GPT-2",
    description="Get a chuckle with AI-generated funny stories! Provide a prompt (or choose one), adjust the desired text length and number of outputs, and let the AI do the rest!",
)

# Launch the Gradio app
interface.launch()
# Output: Running on public URL: https://a1805b7b0e14114fa3.gradio.live

# In[31]:
generate_text("The self-driving car, fed up with rush hour traffic, decided to take a scenic detour through the countryside.", max_length=50, num_return_sequences=10)
# Output: ten "> The self-driving car, ..." continuations (omitted)

# In[18]:
start_tokens
# Output: 'The prophecy was clear: only the chosen one, with the perfect social media profile, could defeat the forces of online negativity...'
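# Note: gr.Interface displays whatever generate_text returns, so gen_text in model.py has
# to return the generated text rather than only printing it (see gen_text in model.py below).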
gpt_124M_30thJune2024.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:951d9d1e0f2869d73873484c2f6ad89866831a9e8daac29550b99b9dcaf4eae9
size 548295018
model.py
ADDED
@@ -0,0 +1,483 @@
# GPT-3 Paper
# add cosine decay
import os
import math
import time
import inspect
from dataclasses import dataclass
import torch
import torch.nn as nn
import tiktoken
from torch.nn import functional as F


class CausalSelfAttention(nn.Module):

    def __init__(self, config):
        super().__init__()

        # assertion to ensure the embedding dimension is divisible by the number of heads (important for reshaping later)
        assert config.n_embd % config.n_head == 0

        # key, query, value projections for all heads, but in a batch. Each vector has the same dimension (C) as the input embedding.
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)

        # output projection, applied after the head outputs are re-assembled
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        self.c_proj.NANOGPT_SCALE_INIT = 1

        # regularization
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size)).view(1, 1, config.block_size, config.block_size))

    def forward(self, x):
        # x is the tokenised version of input.txt
        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        # nh is "number of heads", hs is "head size", and C (number of channels) = nh * hs
        # e.g. in GPT-2 (124M), n_head=12, hs=64, so nh*hs=C=768 channels in the Transformer
        qkv = self.c_attn(x)
        q, k, v = qkv.split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)

        # manual attention, kept for reference:
        # att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        # att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf')) # mask out future positions
        # att = F.softmax(att, dim=-1)
        # y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)

        # This call fuses the dot product, scaling, causal masking, and softmax into a single kernel.
        y = F.scaled_dot_product_attention(q, k, v, is_causal=True) # Flash attention

        y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
        # output projection
        y = self.c_proj(y)
        return y
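# Illustrative shape check (hypothetical usage, run after the definitions below; not part
# of the training or inference path):
#   attn = CausalSelfAttention(GPTConfig())
#   x = torch.randn(2, 16, 768)   # (B=2, T=16, C=n_embd)
#   print(attn(x).shape)          # torch.Size([2, 16, 768]); attention preserves the shape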

class MLP(nn.Module):
    # MLP (Multi-Layer Perceptron)
    ## This class implements a simple multi-layer perceptron (MLP) sub-module.
    ## It's often used within transformers for non-linear transformations.

    def __init__(self, config):
        # expand and squeeze
        super().__init__()
        # c_fc: Projects the input (x) to a dimension four times larger than the embedding dimension (n_embd).
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd)

        # GELU (Gaussian Error Linear Unit) activation function for non-linearity.
        # Here, an approximate version using tanh is used.
        self.gelu = nn.GELU(approximate='tanh')

        # Projects the output back to the original embedding dimension (n_embd).
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd)
        self.c_proj.NANOGPT_SCALE_INIT = 1

    def forward(self, x):
        # Takes the input (x).
        # Applies the linear layer (c_fc), followed by the GELU activation.
        # Applies the final projection layer (c_proj).
        # Returns the transformed output.
        x = self.c_fc(x)
        x = self.gelu(x)
        x = self.c_proj(x)
        return x
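# Illustrative shape check (hypothetical usage): the MLP expands 768 -> 4*768 = 3072,
# applies GELU, then projects back down, so output shape equals input shape:
#   mlp = MLP(GPTConfig())
#   print(mlp(torch.randn(2, 16, 768)).shape)   # torch.Size([2, 16, 768])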

class Block(nn.Module):
    # This class combines the CausalSelfAttention layer (explained previously) and the MLP layer to form a single transformer block.
    # The input is processed through the attention layer, preceded by layer normalization, and then through an MLP,
    # again preceded by layer normalization.

    def __init__(self, config):
        super().__init__()

        # ln_1: A layer normalization layer applied before the causal self-attention.
        # attn: An instance of the CausalSelfAttention class (explained previously).
        # mlp: An instance of the MLP class (explained previously).

        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.mlp = MLP(config)

    def forward(self, x):
        # Takes the input (x).
        # Performs a residual connection with the output from the causal self-attention layer (attn), preceded by layer normalization (ln_1).
        # Performs another residual connection with the output from the MLP layer (mlp), preceded by layer normalization (ln_2).
        # Returns the final output after the second residual connection.
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x
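# Note: this is the "pre-norm" arrangement: LayerNorm sits on the input of each sub-layer,
# while the residual path (the bare `x +`) is left untouched, which keeps gradients flowing
# cleanly through deep stacks. A Block therefore also preserves the (B, T, C) shape.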

@dataclass
class GPTConfig:
    block_size: int = 1024   # max sequence length
    vocab_size: int = 50304  # number of tokens: 50,000 BPE merges + 256 byte tokens + 1 <|endoftext|> token
    n_layer: int = 12        # number of layers
    n_head: int = 12         # number of heads
    n_embd: int = 768        # embedding dimension
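# Worked parameter count for these defaults (lm_head is weight-tied to wte below,
# so it adds no extra parameters):
#   wte: 50304 * 768                               = 38,633,472
#   wpe: 1024 * 768                                =    786,432
#   per block: c_attn     (768*2304 + 2304)        =  1,771,776
#            + attn c_proj (768*768 + 768)         =    590,592
#            + c_fc       (768*3072 + 3072)        =  2,362,368
#            + mlp c_proj (3072*768 + 768)         =  2,360,064
#            + ln_1, ln_2 (2 * 2 * 768)            =      3,072
#            = 7,087,872 per block; x 12 blocks    = 85,054,464
#   ln_f: 2 * 768                                  =      1,536
#   total                                          = 124,475,904  (~124M)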

class GPT(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.config = config

        # Creates a transformer module dictionary containing several key components:
        # wte: Word token embedding layer (nn.Embedding). Maps each token index to its corresponding embedding vector.
        # wpe: Positional embedding layer (nn.Embedding). Adds positional information to the token embeddings.
        # h: A module list containing multiple Block instances (explained earlier). These are the core processing units of the transformer.
        # ln_f: Final layer normalization layer (nn.LayerNorm) applied to the output of the transformer blocks.
        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(config.vocab_size, config.n_embd),
            wpe = nn.Embedding(config.block_size, config.n_embd),
            h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f = nn.LayerNorm(config.n_embd),
        ))

        # Creates the language modeling head (lm_head), a linear layer that projects the final hidden state from the
        # transformer to the vocabulary size, predicting the next token in the sequence.
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # weight sharing: ties the word token embedding layer (wte) to the language modeling head (lm_head).
        # This reduces the number of parameters and encourages the model to learn a meaningful
        # representation for words that can be used for both embedding and prediction.
        self.transformer.wte.weight = self.lm_head.weight

        # weight initialization: initializes the weights of the model using a custom function (_init_weights).
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            std = 0.02
            if hasattr(module, 'NANOGPT_SCALE_INIT'):
                std *= (2 * self.config.n_layer) ** -0.5
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        # idx is of shape (B, T)
        B, T = idx.size()
        assert T <= self.config.block_size, f"Cannot forward sequence of length {T}, block size is only {self.config.block_size}"
        # forward the token and position embeddings
        pos = torch.arange(0, T, dtype=torch.long, device=idx.device) # shape (T)
        pos_emb = self.transformer.wpe(pos) # position embeddings of shape (T, n_embd)
        tok_emb = self.transformer.wte(idx) # token embeddings of shape (B, T, n_embd)
        x = tok_emb + pos_emb
        # forward the blocks of the transformer
        for block in self.transformer.h:
            x = block(x)
        # forward the final layernorm and the classifier
        x = self.transformer.ln_f(x)
        logits = self.lm_head(x) # (B, T, vocab_size)
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss

    @classmethod
    def from_pretrained(cls, model_type):
        """Loads pretrained GPT-2 model weights from huggingface"""
        assert model_type in {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'}
        from transformers import GPT2LMHeadModel
        print("loading weights from pretrained gpt: %s" % model_type)

        # n_layer, n_head and n_embd are determined from model_type
        config_args = {
            'gpt2':        dict(n_layer=12, n_head=12, n_embd=768),  # 124M params
            'gpt2-medium': dict(n_layer=24, n_head=16, n_embd=1024), # 350M params
            'gpt2-large':  dict(n_layer=36, n_head=20, n_embd=1280), # 774M params
            'gpt2-xl':     dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params
        }[model_type]
        config_args['vocab_size'] = 50257 # always 50257 for GPT model checkpoints
        config_args['block_size'] = 1024 # always 1024 for GPT model checkpoints
        # create a from-scratch initialized minGPT model
        config = GPTConfig(**config_args)
        model = GPT(config)
        sd = model.state_dict()
        sd_keys = sd.keys()
        sd_keys = [k for k in sd_keys if not k.endswith('.attn.bias')] # discard this mask / buffer, not a param

        # init a huggingface/transformers model
        model_hf = GPT2LMHeadModel.from_pretrained(model_type)
        sd_hf = model_hf.state_dict()

        # copy while ensuring all of the parameters are aligned and match in names and shapes
        sd_keys_hf = sd_hf.keys()
        sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.masked_bias')] # ignore these, just a buffer
        sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.bias')] # same, just the mask (buffer)
        transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight']
        # basically the openai checkpoints use a "Conv1D" module, but we only want to use a vanilla Linear
        # this means that we have to transpose these weights when we import them
        assert len(sd_keys_hf) == len(sd_keys), f"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}"
        for k in sd_keys_hf:
            if any(k.endswith(w) for w in transposed):
                # special treatment for the Conv1D weights we need to transpose
                assert sd_hf[k].shape[::-1] == sd[k].shape
                with torch.no_grad():
                    sd[k].copy_(sd_hf[k].t())
            else:
                # vanilla copy over the other parameters
                assert sd_hf[k].shape == sd[k].shape
                with torch.no_grad():
                    sd[k].copy_(sd_hf[k])

        return model

    def configure_optimizers(self, weight_decay, learning_rate, device_type):
        # start with all of the candidate parameters (that require grad)
        param_dict = {pn: p for pn, p in self.named_parameters()}
        param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
        # create optim groups. Any parameter that is 2D will be weight decayed, otherwise no.
        # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
        decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
        nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
        optim_groups = [
            {'params': decay_params, 'weight_decay': weight_decay},
            {'params': nodecay_params, 'weight_decay': 0.0}
        ]
        num_decay_params = sum(p.numel() for p in decay_params)
        num_nodecay_params = sum(p.numel() for p in nodecay_params)

        print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
        print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
        # Create AdamW optimizer and use the fused version if it is available
        fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
        use_fused = fused_available and device_type == "cuda"

        print(f"using fused AdamW: {use_fused}")
        optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=(0.9, 0.95), eps=1e-8, fused=use_fused)
        return optimizer
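# Illustrative sketch (hypothetical usage): instantiate the model, confirm the weight
# tying, and run one forward pass on random tokens:
#   m = GPT(GPTConfig())
#   assert m.transformer.wte.weight is m.lm_head.weight    # embedding and head are tied
#   idx = torch.randint(0, GPTConfig.vocab_size, (2, 16))  # (B, T) token ids
#   tgt = torch.randint(0, GPTConfig.vocab_size, (2, 16))
#   logits, loss = m(idx, tgt)
#   print(logits.shape, loss.item())  # torch.Size([2, 16, 50304]), roughly ln(50304) ~ 10.8 at init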
# model = GPT.from_pretrained('gpt2')

device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
    device = "mps"
print(f"using device: {device}")

# SEED
torch.manual_seed(1337)
if torch.cuda.is_available():
    torch.cuda.manual_seed(1337)

# STOP
# num_return_sequences = 5
# max_length = 30


import tiktoken

class DataLoaderLite:
    def __init__(self, B, T):
        self.B = B
        self.T = T

        # at init load tokens from disk and store them in memory
        with open('input.txt', 'r') as f:
            text = f.read()
        enc = tiktoken.get_encoding('gpt2')
        tokens = enc.encode(text)
        self.tokens = torch.tensor(tokens)
        print(f'loaded {len(self.tokens)} tokens')
        print(f'1 epoch = {len(self.tokens) // (B * T)} batches')

        # state
        self.current_position = 0

    def next_batch(self):
        B, T = self.B, self.T
        buf = self.tokens[self.current_position : self.current_position + B * T + 1]
        x = (buf[:-1]).view(B, T) # inputs
        y = (buf[1:]).view(B, T) # targets
        # advance the position in the tensor
        self.current_position += B * T
        # if loading the next batch would be out of bounds, reset
        if self.current_position + (B * T + 1) > len(self.tokens):
            self.current_position = 0
        return x, y
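# How next_batch() offsets targets by one position (toy illustration, assuming
# tokens = [10, 20, 30, 40, 50] and B=1, T=4):
#   buf = [10, 20, 30, 40, 50]
#   x   = [10, 20, 30, 40]   # inputs
#   y   = [20, 30, 40, 50]   # targets: the next token at every position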

# CHANGES IN CURRENT CODE
torch.set_float32_matmul_precision('high')
model = GPT(GPTConfig())
model.to(device)
# model = torch.compile(model)

# CODE UPDATE HERE
max_lr = 6e-4
min_lr = max_lr * 0.1
# warmup_steps = 100
# max_steps = 50

def get_lr(it, warmup_steps, max_steps):
    # linear warmup to max_lr, then a cosine decay down to min_lr
    if it < warmup_steps:
        return max_lr * (it + 1) / warmup_steps
    if it > max_steps:
        return min_lr
    decay_ratio = (it - warmup_steps) / (max_steps - warmup_steps)
    assert 0 <= decay_ratio <= 1
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
    return min_lr + coeff * (max_lr - min_lr)
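# Worked values for this schedule with warmup_steps=100, max_steps=200 (hypothetical numbers):
#   it = 0    -> 6e-4 * 1/100   = 6e-6   (start of linear warmup)
#   it = 99   -> 6e-4 * 100/100 = 6e-4   (end of warmup)
#   it = 150  -> decay_ratio = 0.5, coeff = 0.5, lr = 6e-5 + 0.5*(6e-4 - 6e-5) = 3.3e-4
#   it > 200  -> 6e-5                    (floor at min_lr = max_lr * 0.1)
# Note: run_train() below defaults to max_steps=50 with warmup_steps=100, so with those
# defaults training never leaves the warmup phase.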

# NEW CODE
import time
train_loader = DataLoaderLite(B=8, T=512)

# train_loader = DataLoaderLite(B = B, T = T)
x, y = train_loader.next_batch()
x.shape, y.shape  # no-op at module level; leftover notebook expression

def run_train(max_steps=50, warmup_steps=100, PATH="/content/drive/MyDrive/S21/gpt_124M.pth"):
    # optimizer = torch.optim.AdamW(model.parameters(), lr = 3e-4, betas=(0.9, 0.95), eps=1e-8)
    optimizer = model.configure_optimizers(weight_decay=0.1, learning_rate=6e-4, device_type=device)
    for step in range(max_steps):
        t0 = time.time()
        x, y = train_loader.next_batch()
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        # NEW CODE ADDED HERE
        with torch.autocast(device_type=device, dtype=torch.bfloat16):
            logits, loss = model(x, y)
        loss.backward()
        norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        # NEW CODE: set the learning rate for this step, passing the function
        # arguments through instead of re-hardcoding them
        lr = get_lr(step, warmup_steps=warmup_steps, max_steps=max_steps)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        optimizer.step()
        if device == 'cuda':
            torch.cuda.synchronize()
        t1 = time.time()
        dt = (t1 - t0) * 1000
        tokens_per_sec = (train_loader.B * train_loader.T) / (t1 - t0)
        print(f'step {step} | loss: {loss.item()} | dt: {dt:.2f}ms | tok/sec: {tokens_per_sec:.2f} | norm: {norm:.2f}')
    print(loss)
    torch.save(model.state_dict(), PATH)
    return model
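# Illustrative call (hypothetical): trains the module-level `model` for max_steps steps
# and saves the weights to PATH. The bfloat16 autocast above assumes hardware that
# supports it (e.g. an Ampere-class GPU when device is 'cuda'):
#   trained = run_train(max_steps=50, warmup_steps=100)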

def load_fromsaved(PATH="/content/drive/MyDrive/S21/gpt_124M.pth"):

    # Create a new GPT model instance and move it to the device
    model = GPT(GPTConfig())
    model.to(device)

    # Load the saved weights into the model
    model.load_state_dict(torch.load(PATH))

    # Print confirmation message
    print("Loaded model weights from:", PATH)

    return model

def gen_text(model, start_tokens, max_length=100, num_return_sequences=10):
    """
    Generates text using the loaded GPT model.

    Args:
        model: The GPT model to use for generation.
        start_tokens (str): The starting prompt.
        max_length: The maximum length (in tokens) of each generated sequence.
        num_return_sequences: The number of text sequences to generate.

    Returns:
        str: The generated sequences, printed and joined into a single string.
    """
    enc = tiktoken.get_encoding('gpt2')
    tokens = enc.encode(start_tokens)
    tokens = torch.tensor(tokens, dtype=torch.long)               # (prompt_len,)
    tokens = tokens.unsqueeze(0).repeat(num_return_sequences, 1)  # (num_return_sequences, prompt_len)
    x = tokens.to(device)

    # Set random seeds for consistent generation across runs
    torch.manual_seed(42)
    torch.cuda.manual_seed(42)

    while x.size(1) < max_length:
        # forward the model to get the logits
        with torch.no_grad():
            logits = model(x)[0] # (B, T, vocab_size)
        # take the logits at the last position
        logits = logits[:, -1, :] # (B, vocab_size)
        # get the probabilities
        probs = F.softmax(logits, dim=-1)
        # do top-k sampling of 50 (huggingface pipeline default)
        # topk_probs here becomes (B, 50), topk_indices is (B, 50)
        topk_probs, topk_indices = torch.topk(probs, 50, dim=-1)
        # select a token from the top-k probabilities
        # note: multinomial does not demand the input to sum to 1
        ix = torch.multinomial(topk_probs, 1) # (B, 1)
        # gather the corresponding indices
        xcol = torch.gather(topk_indices, -1, ix) # (B, 1)
        # append to the sequence
        x = torch.cat((x, xcol), dim=1)

    # print the generated text, collecting it so it can also be returned (e.g. to Gradio)
    decoded_texts = []
    for i in range(num_return_sequences):
        tokens = x[i, :max_length].tolist()
        decoded = enc.decode(tokens)
        print(">", decoded)
        decoded_texts.append(decoded)
    # Join all the decoded texts into a single string and return it
    return "\n".join(decoded_texts)



# Earlier draft, kept commented out by the author:
# def gen_text(model, x=x, max_length=100, num_return_sequences=10):
#     torch.manual_seed(42)
#     torch.cuda.manual_seed(42)
#     while x.size(1) < max_length:
#         # forward the model to get the logits
#         with torch.no_grad():
#             logits = model(x)[0] # (B, T, vocab_size)
#         # take the logits at the last position
#         logits = logits[:, -1, :] # (B, vocab_size)
#         # get the probabilities
#         probs = F.softmax(logits, dim=-1)
#         # do top-k sampling of 50 (huggingface pipeline default)
#         # topk_probs here becomes (5, 50), topk_indices is (5, 50)
#         topk_probs, topk_indices = torch.topk(probs, 50, dim=-1)
#         # select a token from the top-k probabilities
#         # note: multinomial does not demand the input to sum to 1
#         ix = torch.multinomial(topk_probs, 1) # (B, 1)
#         # gather the corresponding indices
#         xcol = torch.gather(topk_indices, -1, ix) # (B, 1)
#         # append to the sequence
#         x = torch.cat((x, xcol), dim=1)
#
#     # print the generated text
#     for i in range(num_return_sequences):
#         tokens = x[i, :max_length].tolist()
#         decoded = enc.decode(tokens)
#         print(">", decoded)
training.ipynb
ADDED
The diff for this file is too large to render.