{ "cells": [ { "cell_type": "markdown", "id": "b1b37ca8-25a3-440c-9b68-7f72ce670ade", "metadata": {}, "source": [ "# 2.4 基因大模型的生物序列特征提取" ] }, { "cell_type": "markdown", "id": "d3d04215-2b6c-41fb-92a4-90c82d322ba4", "metadata": {}, "source": [ "使用 GPT-2 模型获取文本的特征向量是一个常见的需求,尤其是在进行文本分类、相似度计算或其他下游任务时。Hugging Face 的 transformers 库提供了简单易用的接口来实现这一点。以下是详细的步骤和代码示例,帮助你从 GPT-2 模型中提取文本的特征向量。" ] }, { "cell_type": "markdown", "id": "3ff5b7c6-e57c-4839-8510-f764154faa65", "metadata": {}, "source": [ "使用 GPT-2 模型获取文本的特征向量是一个常见的需求,尤其是在进行文本分类、相似度计算或其他下游任务时。Hugging Face 的 `transformers` 库提供了简单易用的接口来实现这一点。以下是详细的步骤和代码示例,帮助你从 GPT-2 模型中提取文本的特征向量。\n", "\n", "### 方法 1: 使用隐藏状态(Hidden States)\n", "\n", "GPT-2 是一个基于 Transformer 的语言模型,它在每一层都有隐藏状态(hidden states),这些隐藏状态可以作为文本的特征表示。你可以选择最后一层的隐藏状态作为最终的特征向量,或者对多层的隐藏状态进行平均或拼接。\n", "\n", "\n", "### 方法 2: 使用池化策略\n", "\n", "另一种方法是通过对所有 token 的隐藏状态进行池化操作来获得句子级别的特征向量。常见的池化方法包括:\n", "\n", "- **均值池化**(Mean Pooling):对所有 token 的隐藏状态求平均。\n", "- **最大池化**(Max Pooling):对每个维度取最大值。" ] }, { "cell_type": "code", "execution_count": 43, "id": "e7fe053b-d6da-488a-9c62-24e4b40a992d", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'input_ids': tensor([[ 1, 191, 29, 753, 1241, 2104, 12297, 357, 85, 4395,\n", " 26392, 16]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}\n", "torch.Size([768])\n", "torch.Size([768])\n", "torch.Size([768])\n" ] } ], "source": [ "from transformers import AutoTokenizer, AutoModel\n", "tokenizer = AutoTokenizer.from_pretrained('dna_bpe_dict')\n", "tokenizer.tokenize(\"GAGCACATTCGCCTGCGTGCGCACTCACACACACGTTCAAAAAGAGTCCATTCGATTCTGGCAGTAG\")\n", "#result: [G','AGCAC','ATTCGCC',....]\n", "\n", "model = AutoModel.from_pretrained('dna_gpt2_v0')\n", "import torch\n", "dna = \"ACGTAGCATCGGATCTATCTATCGACACTTGGTTATCGATCTACGAGCATCTCGTTAGC\"\n", "inputs = tokenizer(dna, return_tensors = 'pt')\n", "print(inputs)\n", "\n", "outputs = model(inputs[\"input_ids\"])\n", "#outputs = model(**inputs)\n", "\n", "hidden_states = outputs.last_hidden_state # [1, sequence_length, 768] outputs.last_hidden_state or outputs[0]\n", "\n", "# embedding with mean pooling\n", "embedding_mean = torch.mean(hidden_states[0], dim=0)\n", "print(embedding_mean.shape) # expect to be 768\n", "\n", "# embedding with max pooling\n", "embedding_max = torch.max(hidden_states[0], dim=0)[0]\n", "print(embedding_max.shape) # expect to be 768\n", "\n", "# embedding with first token\n", "embedding_first_token = hidden_states[0][0]\n", "print(embedding_first_token.shape) # expect to be 768" ] }, { "cell_type": "code", "execution_count": 44, "id": "a1f2b545-283a-4613-a953-beb82f427826", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Some weights of BertModel were not initialized from the model checkpoint at dna_bert_v0 and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "{'input_ids': tensor([[ 6, 200, 16057, 10, 1256, 2123, 12294, 366, 13138, 7826,\n", " 82, 25]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}\n", "torch.Size([768])\n", "torch.Size([768])\n", "torch.Size([768])\n" ] } ], "source": [ "from transformers import AutoTokenizer, AutoModel\n", "import torch\n", "\n", "tokenizer = 
{ "cell_type": "code", "execution_count": 44, "id": "a1f2b545-283a-4613-a953-beb82f427826", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Some weights of BertModel were not initialized from the model checkpoint at dna_bert_v0 and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "{'input_ids': tensor([[ 6, 200, 16057, 10, 1256, 2123, 12294, 366, 13138, 7826,\n", " 82, 25]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}\n", "torch.Size([768])\n", "torch.Size([768])\n", "torch.Size([768])\n" ] } ], "source": [ "from transformers import AutoTokenizer, AutoModel\n", "import torch\n", "\n", "# the same extraction works with the WordPiece tokenizer and the DNA BERT model\n", "tokenizer = AutoTokenizer.from_pretrained('dna_wordpiece_dict')\n", "tokenizer.tokenize(\"GAGCACATTCGCCTGCGTGCGCACTCACACACACGTTCAAAAAGAGTCCATTCGATTCTGGCAGTAG\")\n", "# result: ['G', 'AGCAC', 'ATTCGCC', ...]\n", "\n", "model = AutoModel.from_pretrained('dna_bert_v0')\n", "dna = \"ACGTAGCATCGGATCTATCTATCGACACTTGGTTATCGATCTACGAGCATCTCGTTAGC\"\n", "inputs = tokenizer(dna, return_tensors='pt')\n", "print(inputs)\n", "\n", "outputs = model(inputs[\"input_ids\"])\n", "# outputs = model(**inputs)\n", "\n", "hidden_states = outputs.last_hidden_state  # [1, sequence_length, 768]; same as outputs[0]\n", "\n", "# embedding with mean pooling\n", "embedding_mean = torch.mean(hidden_states[0], dim=0)\n", "print(embedding_mean.shape)  # expect torch.Size([768])\n", "\n", "# embedding with max pooling\n", "embedding_max = torch.max(hidden_states[0], dim=0)[0]\n", "print(embedding_max.shape)  # expect torch.Size([768])\n", "\n", "# embedding with first token\n", "embedding_first_token = hidden_states[0][0]\n", "print(embedding_first_token.shape)  # expect torch.Size([768])" ] },
{ "cell_type": "markdown", "id": "56761874-9af7-4b90-aa8b-131e5b8c69b6", "metadata": {}, "source": [ "## Feature extraction and classification\n", "\n", "We use the \"dnagpt/dna_core_promoter\" dataset from Chapter 1 to demonstrate extracting sequence features with the DNA GPT-2 (or DNA BERT) model we trained, and then classifying the sequences with the most basic method, logistic regression." ] },
{ "cell_type": "code", "execution_count": 45, "id": "f1ca177c-a80f-48a1-b2f9-16c13b3350db", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "\"\\nimport os\\n\\n# set env vars (autodl dedicated zone / other IDCs)\\nos.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'\\n\\n# print the env var to confirm it was set\\nprint(os.environ.get('HF_ENDPOINT'))\\n\"" ] }, "execution_count": 45, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import subprocess\n", "import os\n", "\n", "# set proxy environment variables (autodl standard zone)\n", "result = subprocess.run('bash -c \"source /etc/network_turbo && env | grep proxy\"', shell=True, capture_output=True, text=True)\n", "output = result.stdout\n", "for line in output.splitlines():\n", "    if '=' in line:\n", "        var, value = line.split('=', 1)\n", "        os.environ[var] = value\n", "\n", "# Alternatively:\n", "\"\"\"\n", "import os\n", "\n", "# set env vars (autodl dedicated zone / other IDCs)\n", "os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'\n", "\n", "# print the env var to confirm it was set\n", "print(os.environ.get('HF_ENDPOINT'))\n", "\"\"\"" ] },
{ "cell_type": "code", "execution_count": 46, "id": "2295739c-e80a-47be-9400-88bfab4b0bb6", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "DatasetDict({\n", "    train: Dataset({\n", "        features: ['sequence', 'label'],\n", "        num_rows: 59196\n", "    })\n", "})" ] }, "execution_count": 46, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from datasets import load_dataset\n", "\n", "dna_data = load_dataset(\"dnagpt/dna_core_promoter\")\n", "dna_data" ] },
{ "cell_type": "markdown", "id": "c804bced-f151-43a7-8a95-156db358da3e", "metadata": {}, "source": [ "Here we do not need to worry about the biological meaning of this data; it is enough to know that `sequence` is the DNA sequence and `label` is the classification label, with two classes, 0 and 1." ] },
{ "cell_type": "code", "execution_count": 47, "id": "9a47a1b1-21f2-4d71-801c-50f88e326ed3", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'sequence': 'CATGCGGGTCGATATCCTATCTGAATCTCTCAGCCCAAGAGGGAGTCCGCTCATCTATTCGGCAGTACTG',\n", " 'label': 0}" ] }, "execution_count": 47, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dna_data[\"train\"][0]" ] },
{ "cell_type": "markdown", "id": "cde7986d-a225-41ca-8f11-614d079fd2bf", "metadata": {}, "source": [ "We use the scikit-learn library to build the logistic regression classifier. The first step is feature extraction:" ] },
"outputs": [], "source": [ "import numpy as np\n", "from sklearn.model_selection import train_test_split\n", "from sklearn.linear_model import LogisticRegression\n", "from sklearn.datasets import load_iris\n", "from sklearn.metrics import accuracy_score\n", "\n", "\n", "def get_gpt2_feature(sequence):\n", " return \n", "\n", "# 加载数据集\n", "data = load_iris()\n", "X = data.data[data.target < 2] # 只选择前两个类别\n", "y = data.target[data.target < 2]\n", "\n", "X = []\n", "Y = []\n", "\n", "for item in dna_data[\"train\"]:\n", " sequence = item[\"sequence\"]\n", " label = item[\"label\"]\n", " x_v = get_gpt2_feature(sequence)\n", " y_v = label\n", " X.append(x_v)\n", " Y.append(y_v)" ] }, { "cell_type": "code", "execution_count": 49, "id": "8af0effa-b2b6-4e49-9256-cead146d848c", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[5.1, 3.5, 1.4, 0.2],\n", " [4.9, 3. , 1.4, 0.2],\n", " [4.7, 3.2, 1.3, 0.2],\n", " [4.6, 3.1, 1.5, 0.2],\n", " [5. , 3.6, 1.4, 0.2],\n", " [5.4, 3.9, 1.7, 0.4],\n", " [4.6, 3.4, 1.4, 0.3],\n", " [5. , 3.4, 1.5, 0.2],\n", " [4.4, 2.9, 1.4, 0.2],\n", " [4.9, 3.1, 1.5, 0.1],\n", " [5.4, 3.7, 1.5, 0.2],\n", " [4.8, 3.4, 1.6, 0.2],\n", " [4.8, 3. , 1.4, 0.1],\n", " [4.3, 3. , 1.1, 0.1],\n", " [5.8, 4. , 1.2, 0.2],\n", " [5.7, 4.4, 1.5, 0.4],\n", " [5.4, 3.9, 1.3, 0.4],\n", " [5.1, 3.5, 1.4, 0.3],\n", " [5.7, 3.8, 1.7, 0.3],\n", " [5.1, 3.8, 1.5, 0.3],\n", " [5.4, 3.4, 1.7, 0.2],\n", " [5.1, 3.7, 1.5, 0.4],\n", " [4.6, 3.6, 1. , 0.2],\n", " [5.1, 3.3, 1.7, 0.5],\n", " [4.8, 3.4, 1.9, 0.2],\n", " [5. , 3. , 1.6, 0.2],\n", " [5. , 3.4, 1.6, 0.4],\n", " [5.2, 3.5, 1.5, 0.2],\n", " [5.2, 3.4, 1.4, 0.2],\n", " [4.7, 3.2, 1.6, 0.2],\n", " [4.8, 3.1, 1.6, 0.2],\n", " [5.4, 3.4, 1.5, 0.4],\n", " [5.2, 4.1, 1.5, 0.1],\n", " [5.5, 4.2, 1.4, 0.2],\n", " [4.9, 3.1, 1.5, 0.2],\n", " [5. , 3.2, 1.2, 0.2],\n", " [5.5, 3.5, 1.3, 0.2],\n", " [4.9, 3.6, 1.4, 0.1],\n", " [4.4, 3. , 1.3, 0.2],\n", " [5.1, 3.4, 1.5, 0.2],\n", " [5. , 3.5, 1.3, 0.3],\n", " [4.5, 2.3, 1.3, 0.3],\n", " [4.4, 3.2, 1.3, 0.2],\n", " [5. , 3.5, 1.6, 0.6],\n", " [5.1, 3.8, 1.9, 0.4],\n", " [4.8, 3. , 1.4, 0.3],\n", " [5.1, 3.8, 1.6, 0.2],\n", " [4.6, 3.2, 1.4, 0.2],\n", " [5.3, 3.7, 1.5, 0.2],\n", " [5. , 3.3, 1.4, 0.2],\n", " [7. , 3.2, 4.7, 1.4],\n", " [6.4, 3.2, 4.5, 1.5],\n", " [6.9, 3.1, 4.9, 1.5],\n", " [5.5, 2.3, 4. , 1.3],\n", " [6.5, 2.8, 4.6, 1.5],\n", " [5.7, 2.8, 4.5, 1.3],\n", " [6.3, 3.3, 4.7, 1.6],\n", " [4.9, 2.4, 3.3, 1. ],\n", " [6.6, 2.9, 4.6, 1.3],\n", " [5.2, 2.7, 3.9, 1.4],\n", " [5. , 2. , 3.5, 1. ],\n", " [5.9, 3. , 4.2, 1.5],\n", " [6. , 2.2, 4. , 1. ],\n", " [6.1, 2.9, 4.7, 1.4],\n", " [5.6, 2.9, 3.6, 1.3],\n", " [6.7, 3.1, 4.4, 1.4],\n", " [5.6, 3. , 4.5, 1.5],\n", " [5.8, 2.7, 4.1, 1. ],\n", " [6.2, 2.2, 4.5, 1.5],\n", " [5.6, 2.5, 3.9, 1.1],\n", " [5.9, 3.2, 4.8, 1.8],\n", " [6.1, 2.8, 4. , 1.3],\n", " [6.3, 2.5, 4.9, 1.5],\n", " [6.1, 2.8, 4.7, 1.2],\n", " [6.4, 2.9, 4.3, 1.3],\n", " [6.6, 3. , 4.4, 1.4],\n", " [6.8, 2.8, 4.8, 1.4],\n", " [6.7, 3. , 5. , 1.7],\n", " [6. , 2.9, 4.5, 1.5],\n", " [5.7, 2.6, 3.5, 1. ],\n", " [5.5, 2.4, 3.8, 1.1],\n", " [5.5, 2.4, 3.7, 1. ],\n", " [5.8, 2.7, 3.9, 1.2],\n", " [6. , 2.7, 5.1, 1.6],\n", " [5.4, 3. , 4.5, 1.5],\n", " [6. , 3.4, 4.5, 1.6],\n", " [6.7, 3.1, 4.7, 1.5],\n", " [6.3, 2.3, 4.4, 1.3],\n", " [5.6, 3. , 4.1, 1.3],\n", " [5.5, 2.5, 4. , 1.3],\n", " [5.5, 2.6, 4.4, 1.2],\n", " [6.1, 3. , 4.6, 1.4],\n", " [5.8, 2.6, 4. , 1.2],\n", " [5. , 2.3, 3.3, 1. ],\n", " [5.6, 2.7, 4.2, 1.3],\n", " [5.7, 3. 
, 4.2, 1.2],\n", " [5.7, 2.9, 4.2, 1.3],\n", " [6.2, 2.9, 4.3, 1.3],\n", " [5.1, 2.5, 3. , 1.1],\n", " [5.7, 2.8, 4.1, 1.3]])" ] }, "execution_count": 49, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X" ] }, { "cell_type": "code", "execution_count": 51, "id": "868a3cab-e991-4990-9ec5-3e632a41a599", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])" ] }, "execution_count": 51, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y" ] }, { "cell_type": "code", "execution_count": null, "id": "5ab0c188-6476-43c4-b361-a2bfe0ec7a8a", "metadata": {}, "outputs": [], "source": [ "# 将数据分为训练集和测试集\n", "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)\n", "\n", "# 创建逻辑回归模型\n", "model = LogisticRegression()\n", "\n", "# 训练模型\n", "model.fit(X_train, y_train)\n", "\n", "# 在测试集上进行预测\n", "y_pred = model.predict(X_test)\n", "\n", "# 计算准确率\n", "accuracy = accuracy_score(y_test, y_pred)\n", "print(f\"Accuracy: {accuracy * 100:.2f}%\")\n", "\n", "# 输出部分预测结果与真实标签对比\n", "for i in range(5):\n", " print(f\"True: {y_test[i]}, Predicted: {y_pred[i]}\")" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.3" } }, "nbformat": 4, "nbformat_minor": 5 }