{ "cells": [ { "cell_type": "markdown", "id": "c499a5c3-0244-41c4-9947-e166206204e2", "metadata": {}, "source": [ "# 3.5 回归类任务" ] }, { "cell_type": "markdown", "id": "4678171b-bbc8-49dd-ad04-48f5ef89b45e", "metadata": {}, "source": [ "GPT-2 原本是设计用于生成自然语言的模型,但通过适当的调整和微调,它也可以用于回归任务,例如预测连续值。\n", "\n", "使用 GPT-2 进行回归问题的解决,可以将回归问题转化为自回归语言模型任务。GPT-2 原本是设计用于生成自然语言的模型,但通过适当的调整和微调,它也可以用于回归任务,例如预测连续值(如情感评分、价格预测等)。\n", "\n", "---\n", "\n", "### **1. 使用 GPT-2 做回归的核心思路**\n", "\n", "1. **调整输出层**:\n", " - 默认情况下,GPT-2 的输出是一个词汇表大小的概率分布,用于预测下一个 token。\n", " - 对于回归问题,可以将模型的最后一层替换为一个线性层,使得输出变为一个标量或多个连续值。\n", " - gpt2的huggingface实现中,可以简单设置1个分类的分类header,实现回归预测。\n", "\n", "2. **损失函数**:\n", " - 对于回归问题,使用均方误差(MSE)或均绝对误差(MAE)作为损失函数,而不是分类任务中常用的交叉熵。\n", "\n", "3. **输入格式**:\n", " - 输入数据仍然是文本,可以通过特定的模板形式加入上下文信息。\n", "\n", "---\n", "\n", "### **2. GPT-2 回归任务的实现步骤**\n", "\n", "#### **(1)加载基础模型**\n", "\n", "从 Hugging Face Transformers 库加载 GPT-2 模型和分词器,并调整其配置以适应回归任务。\n", "\n", "```python\n", "from transformers import GPT2Tokenizer, GPT2Model, GPT2Config, AutoModelForSequenceClassification\n", "\n", "# 加载分词器\n", "tokenizer = GPT2Tokenizer.from_pretrained(\"gpt2\")\n", "\n", "# 调整模型配置,num_labels=1 表示回归任务\n", "config = GPT2Config.from_pretrained(\"gpt2\", num_labels=1)\n", "\n", "# 加载模型,增加回归输出\n", "model = AutoModelForSequenceClassification.from_pretrained(\"gpt2\", config=config)\n", "```\n", "\n", "---\n", "\n", "### **3. 课程数据集**\n", "\n", "本例程使用了蛋白质稳定性分析的数据集,也就是一个蛋白质序列,对应一个float的数值,做回归预测分析。\n", "\n", "**蛋白质稳定性分析**是研究蛋白质在不同条件下保持其结构和功能的能力的过程。蛋白质稳定性是生物化学和生物技术领域的重要课题,影响着蛋白质的折叠、功能执行、以及在应用中的可用性(如工业酶、药物开发等)。\n" ] }, { "cell_type": "code", "execution_count": 1, "id": "1e8c0f86-af78-43e1-8db4-e2a2ea22f815", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "\"\\nimport os\\n\\n# 设置环境变量, autodl专区 其他idc\\nos.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'\\n\\n# 打印环境变量以确认设置成功\\nprint(os.environ.get('HF_ENDPOINT'))\\n\"" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import subprocess\n", "import os\n", "# 设置环境变量, autodl一般区域\n", "result = subprocess.run('bash -c \"source /etc/network_turbo && env | grep proxy\"', shell=True, capture_output=True, text=True)\n", "output = result.stdout\n", "for line in output.splitlines():\n", " if '=' in line:\n", " var, value = line.split('=', 1)\n", " os.environ[var] = value\n", "\n", "\"\"\"\n", "import os\n", "\n", "# 设置环境变量, autodl专区 其他idc\n", "os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'\n", "\n", "# 打印环境变量以确认设置成功\n", "print(os.environ.get('HF_ENDPOINT'))\n", "\"\"\"" ] }, { "cell_type": "code", "execution_count": 2, "id": "c51a8d69-9a36-47e7-8084-f64e6a72e4f7", "metadata": {}, "outputs": [], "source": [ "from transformers import AutoTokenizer, AutoModel\n", "from tokenizers import Tokenizer\n", "from transformers import GPT2LMHeadModel, AutoConfig,GPT2Tokenizer\n", "from transformers import AutoModelForSequenceClassification\n", "from transformers import DataCollatorWithPadding" ] }, { "cell_type": "code", "execution_count": 3, "id": "a5aeb7c1-2d2a-4f57-ad8c-659613870e59", "metadata": {}, "outputs": [], "source": [ "#set tokenizer\n", "tokenizer = GPT2Tokenizer.from_pretrained(\"dnagpt/gene_eng_gpt2_v0\")\n", "tokenizer.pad_token = tokenizer.eos_token" ] }, { "cell_type": "code", "execution_count": 4, "id": "ad0c19cd-96a5-463e-8b7d-439646fef429", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Some weights of GPT2ForSequenceClassification were not initialized from the model checkpoint at dnagpt/gene_eng_gpt2_v0 and are newly initialized: ['score.weight']\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" ] } ], "source": [ "#set model\n", "model = AutoModelForSequenceClassification.from_pretrained('dnagpt/gene_eng_gpt2_v0', num_labels=1)\n", "model.config.pad_token_id = model.config.eos_token_id" ] }, { "cell_type": "code", "execution_count": 5, "id": "8c48cb0a-6142-4afc-823e-08fb33f74222", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "DatasetDict({\n", " train: Dataset({\n", " features: ['seq_id', 'seq_type', 'seq', 'label'],\n", " num_rows: 62079\n", " })\n", " test: Dataset({\n", " features: ['seq_id', 'seq_type', 'seq', 'label'],\n", " num_rows: 6898\n", " })\n", "})" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from datasets import load_dataset\n", "# 1. load ~11k samples from promoters prediction dataset\n", "dataset = load_dataset(\"csv\", data_files=\"data/protein_stab.csv\")['train'].train_test_split(test_size=0.1)\n", "dataset" ] }, { "cell_type": "code", "execution_count": 6, "id": "685dd025-f00a-4869-bc30-9843c77b6d8a", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'seq_id': 'train_prot_32672',\n", " 'seq_type': 'prot',\n", " 'seq': 'FYRLIIFKYPDYIDTYLRLAAIAKEKNNLQLSIEGNGSGGNGSGGNGSGN',\n", " 'label': 0.7599999904632561}" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dataset[\"train\"][0]" ] }, { "cell_type": "code", "execution_count": 7, "id": "6e10dbbb-73ef-4b67-8290-77f8896298f5", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "datasets mean token lenght 17.24006958538707 min token length 12 max token length 35\n" ] } ], "source": [ "token_len_list = []\n", "for item in dataset[\"test\"]:\n", " inputs = tokenizer.tokenize(item[\"seq\"])\n", " token_len_list.append( len(inputs) )\n", "\n", "mean_len = sum(token_len_list)/len(token_len_list)\n", "min_len = min(token_len_list)\n", "max_len = max(token_len_list)\n", "\n", "print(\"datasets \", \"mean token lenght\", mean_len, \"min token length\", min_len, \"max token length\", max_len)" ] }, { "cell_type": "code", "execution_count": 25, "id": "ac58b5b4-bff0-404d-bcf5-2b93db2b37c0", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "419cce8c5ba249ac8c8773dd2d69992d", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Map: 0%| | 0/62079 [00:00\n", " \n", " \n", " [30987/31040 1:00:56 < 00:06, 8.47 it/s, Epoch 9.98/10]\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
EpochTraining LossValidation LossRmse
10.0446000.1634620.163462
20.0419000.1579000.157900
30.0377000.1597240.159724
40.0317000.1576860.157686
50.0288000.1571240.157124
60.0254000.1508520.150852
70.0223000.1592930.159293
80.0196000.1546080.154608
90.0173000.1561040.156104

" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "IOPub message rate exceeded.\n", "The Jupyter server will temporarily stop sending output\n", "to the client in order to avoid crashing it.\n", "To change this limit, set the config variable\n", "`--ServerApp.iopub_msg_rate_limit`.\n", "\n", "Current values:\n", "ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n", "ServerApp.rate_limit_window=3.0 (secs)\n", "\n" ] } ], "source": [ "# 开始训练\n", "trainer.train()" ] }, { "cell_type": "code", "execution_count": null, "id": "060c4618-40d0-4934-bab8-36aab3a46de5", "metadata": {}, "outputs": [], "source": [ "#模型测试\n", "predictions = trainer.predict(tokenized_datasets[\"test\"])\n", "predictions" ] }, { "cell_type": "code", "execution_count": 18, "id": "1f8ef885-5bc9-4668-905b-6b2235209654", "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "

\n", " \n", " \n", " [345/345 00:09]\n", "
\n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "{'eval_loss': 0.15949687361717224,\n", " 'eval_rmse': 0.15949687361717224,\n", " 'eval_runtime': 9.1483,\n", " 'eval_samples_per_second': 754.017,\n", " 'eval_steps_per_second': 37.712,\n", " 'epoch': 10.0}" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "trainer.evaluate()" ] }, { "cell_type": "code", "execution_count": 23, "id": "afabdbe9-9b96-4f9e-bef2-1d819431f8d1", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[ 1.7208484 ]\n", " [ 0.00225139]\n", " [ 0.3325616 ]\n", " [-0.34372616]\n", " [-0.45505935]\n", " [-0.06892765]\n", " [ 0.15099108]\n", " [ 0.12211376]\n", " [ 0.3947332 ]\n", " [ 0.23186803]]\n" ] } ], "source": [ "predictions.predictions[0:10].squeeze()" ] }, { "cell_type": "code", "execution_count": 24, "id": "fa9d17fd-eece-4c1e-99e0-3d19d36f7584", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([ 1.69, 0.84, 0.58, -0.15, 0.23, 0.03, 0.15, 0.2 , 0.51,\n", " 1.1 ], dtype=float32)" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "predictions.label_ids[0:10]" ] }, { "cell_type": "code", "execution_count": null, "id": "52252015-e068-414b-bd8a-79a5d1a2beec", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.3" } }, "nbformat": 4, "nbformat_minor": 5 }