Update main.py
Browse files
main.py
CHANGED
@@ -1,7 +1,224 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
from fastapi import FastAPI
|
2 |
|
3 |
app = FastAPI()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
|
5 |
-
@app.
|
6 |
-
def
|
7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import transformers
|
2 |
+
import torch
|
3 |
+
import os
|
4 |
+
import json
|
5 |
+
import random
|
6 |
+
import numpy as np
|
7 |
+
from torch import nn
|
8 |
+
import argparse
|
9 |
+
from torch.utils.tensorboard import SummaryWriter
|
10 |
+
from datetime import datetime
|
11 |
+
from tqdm import tqdm
|
12 |
+
from torch.nn import DataParallel
|
13 |
+
import logging
|
14 |
+
from transformers import GPT2TokenizerFast, GPT2LMHeadModel, GPT2Config
|
15 |
+
from transformers import BertTokenizerFast
|
16 |
+
# from transformers import BertTokenizer
|
17 |
+
from os.path import join, exists
|
18 |
+
from itertools import zip_longest, chain
|
19 |
+
# from chatbot.model import DialogueGPT2Model
|
20 |
+
# from dataset import MyDataset
|
21 |
+
from torch.utils.data import Dataset, DataLoader
|
22 |
+
from torch.nn import CrossEntropyLoss
|
23 |
+
from sklearn.model_selection import train_test_split
|
24 |
+
import torch.nn.functional as F
|
25 |
+
from transformers import AutoTokenizer, AutoConfig, get_linear_schedule_with_warmup, AdamW, BertModel
|
26 |
+
|
27 |
+
# Padding token string handed to the tokenizer below.
PAD = '[PAD]'
# Assumed vocabulary id of the padding token — TODO confirm against vocab.txt.
pad_id = 0
|
29 |
+
|
30 |
+
|
31 |
+
def set_args(argv=None):
    """Parse command-line options for the inference service.

    Args:
        argv: Optional list of argument strings. ``None`` (the default)
            falls back to ``sys.argv[1:]``, so existing zero-argument
            callers keep their behavior; passing a list makes the
            function testable and embeddable.

    Returns:
        ``argparse.Namespace`` with the parsed options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--device', default='0', type=str, required=False, help='生成设备')
    parser.add_argument('--log_path', default='interact.log', type=str, required=False, help='interact日志存放位置')
    parser.add_argument('--model_path', default='./pathology_extra/result/12/model.pth', type=str, required=False, help='对话模型路径')
    parser.add_argument('--vocab_path', default='D:\\transformerFileDownload\\Pytorch\\bert-base-zh\\vocab.txt', type=str, required=False,
                        help='对话模型路径')
    parser.add_argument('--save_samples_path', default="sample/", type=str, required=False, help="保存聊天记录的文件路径")
    parser.add_argument('--repetition_penalty', default=1.0, type=float, required=False,
                        help="重复惩罚参数,若生成的对话重复性较高,可适当提高该参数")
    parser.add_argument('--max_len', type=int, default=25, help='每个utterance的最大长度,超过指定长度则进行截断')
    parser.add_argument('--max_history_len', type=int, default=3, help="dialogue history的最大长度")
    parser.add_argument('--no_cuda', action='store_true', help='不使用GPU进行预测')
    return parser.parse_args(argv)
|
51 |
+
|
52 |
+
|
53 |
+
def create_logger(args):
    """Build a logger that writes to both ``args.log_path`` and the console.

    Args:
        args: Namespace with a ``log_path`` attribute (the log file path).

    Returns:
        The module logger, configured with a file handler and a stream
        handler, both using the same timestamped format.
    """
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.INFO)

    # Idempotence guard: without it, every call stacks another pair of
    # handlers on the same module logger and each record is emitted
    # multiple times.
    if logger.handlers:
        return logger

    formatter = logging.Formatter(
        '%(asctime)s - %(levelname)s - %(message)s')

    # Handler that writes records to the log file.
    file_handler = logging.FileHandler(
        filename=args.log_path)
    file_handler.setFormatter(formatter)
    file_handler.setLevel(logging.INFO)
    logger.addHandler(file_handler)

    # Handler that mirrors records to the console. Its DEBUG level is
    # effectively capped at INFO by the logger-level setting above.
    console = logging.StreamHandler()
    console.setLevel(logging.DEBUG)
    console.setFormatter(formatter)
    logger.addHandler(console)

    return logger
|
77 |
+
|
78 |
+
class Word_BERT(nn.Module):
    """BERT encoder with four prediction heads.

    One per-token head (``out``) scores every position of the encoded
    sequence; three pooled heads classify the whole input (cancer type,
    transfer, lymph transfer). Each head is ``Dropout(0.1)`` followed by a
    linear projection from BERT's 768-dim hidden size.
    """

    def __init__(self, seq_label=1, cancer_label=8, transfer_label=2, ly_transfer=2):
        super(Word_BERT, self).__init__()
        # Encoder weights come from a local pretrained checkpoint.
        self.bert = BertModel.from_pretrained('D:\\transformerFileDownload\\Pytorch\\bert-base-zh')

        def head(out_features):
            # Shared head shape: dropout then a 768 -> out_features projection.
            return nn.Sequential(
                nn.Dropout(0.1),
                nn.Linear(768, out_features)
            )

        self.out = head(seq_label)
        self.cancer = head(cancer_label)
        self.transfer = head(transfer_label)
        self.ly_transfer = head(ly_transfer)

    def forward(self, word_input, masks):
        """Encode ``word_input`` (with attention ``masks``) and apply all heads.

        Returns:
            Tuple ``(out, cancer, transfer, ly_transfer)``: per-token logits
            from the last hidden state, plus three logit vectors computed
            from the pooled output.
        """
        encoded = self.bert(word_input, attention_mask=masks)
        token_states = encoded.last_hidden_state
        pooled = encoded.pooler_output
        return (
            self.out(token_states),
            self.cancer(pooled),
            self.transfer(pooled),
            self.ly_transfer(pooled),
        )
|
114 |
+
|
115 |
+
def _extract_spans(flags):
    """Return ``(start, end)`` half-open spans of consecutive 1s in a 0/1 list.

    Fixes two defects of the original inline scan, whose ``start == end``
    sentinel collided with index 0: a span beginning at position 0 was never
    opened, and a span still open at the end of the sequence was dropped.
    """
    spans = []
    start = None  # index where the current run of 1s began, or None
    for i, flag in enumerate(flags):
        if flag == 1 and start is None:
            start = i
        elif flag != 1 and start is not None:
            spans.append((start, i))
            start = None
    if start is not None:
        # Close a run that extends to the end of the sequence.
        spans.append((start, len(flags)))
    return spans


def getChat(text: str, userid: int):
    """Run the pathology-extraction model on one input string.

    Args:
        text: Raw input text; tokenized per character and wrapped in
            [CLS]/[SEP] before id conversion.
        userid: Caller identifier (currently unused in this function).

    Returns:
        Tuple ``(size_list, cancer, transfer, ly_transfer)`` where
        ``size_list`` is a list of (start, end) token spans flagged by the
        per-token head and the other three are human-readable labels.

    Uses the module-level ``tokenizer``, ``model`` and ``device`` set up at
    import time.
    """
    tokens = ['[CLS]'] + [ch for ch in text] + ['[SEP]']
    text_ids = tokenizer.convert_tokens_to_ids(tokens)

    input_ids = torch.tensor(text_ids).long().to(device)
    input_ids = input_ids.unsqueeze(0)  # add batch dimension
    mask_input = torch.ones_like(input_ids).long().to(device)

    with torch.no_grad():
        out, cancer, transfer, ly_transfer = model(input_ids, mask_input)
        # F.sigmoid is deprecated; torch.sigmoid is the supported spelling.
        out = torch.sigmoid(out).squeeze(2).cpu()
        out = out.numpy().tolist()
        cancer = cancer.argmax(dim=-1).cpu().numpy().tolist()
        transfer = transfer.argmax(dim=-1).cpu().numpy().tolist()
        ly_transfer = ly_transfer.argmax(dim=-1).cpu().numpy().tolist()

    # Threshold the per-token probabilities at 0.4, then collect runs of 1s.
    pred_thresold = [[1 if jj > 0.4 else 0 for jj in ii] for ii in out]
    size_list = _extract_spans(pred_thresold[0])
    print(size_list)

    # Map predicted class indices back to their label strings.
    cancer_dict = {'腺癌': 0, '肺良性疾病': 1, '鳞癌': 2, '无法判断组织分型': 3, '复合型': 4, '转移癌': 5, '小细胞癌': 6, '大细胞癌': 7}
    id_cancer = {j: i for i, j in cancer_dict.items()}
    transfer_id = {'无': 0, '转移': 1}
    id_transfer = {j: i for i, j in transfer_id.items()}
    lymph_transfer_id = {'无': 0, '淋巴转移': 1}
    id_lymph_transfer = {j: i for i, j in lymph_transfer_id.items()}
    cancer = id_cancer[cancer[0]]
    transfer = id_transfer[transfer[0]]
    ly_transfer = id_lymph_transfer[ly_transfer[0]]
    print(cancer, transfer, ly_transfer)

    return size_list, cancer, transfer, ly_transfer
|
168 |
+
|
169 |
+
|
170 |
+
# --- Service bootstrap: runs at import time, in this exact order ---
import requests

import uvicorn
from pydantic import BaseModel
from fastapi import FastAPI

app = FastAPI()
# import intel_extension_for_pytorch as ipex

args = set_args()
logger = create_logger(args)
# Use the GPU only when CUDA is available and --no_cuda was not given.
args.cuda = torch.cuda.is_available() and not args.no_cuda
device = 'cuda' if args.cuda else 'cpu'
logger.info('using device:{}'.format(device))
os.environ["CUDA_VISIBLE_DEVICES"] = args.device
tokenizer = BertTokenizerFast(vocab_file=args.vocab_path, sep_token="[SEP]", pad_token="[PAD]", cls_token="[CLS]")
# tokenizer = BertTokenizer(vocab_file=args.voca_path)
model = Word_BERT()
# NOTE(review): the checkpoint load below is commented out, so the service
# runs with randomly initialized prediction heads — confirm this is intended.
# If re-enabled, note that load_state_dict() returns key lists, not the
# model, so do NOT rebind `model` to its result.
# model = model.load_state_dict(torch.load(args.model_path))
model = model.to(device)
model.eval()
# history = []
Allhistory = {}  # per-user history store; never written in the visible code
print('初始化完成')  # "initialization complete"
|
195 |
+
|
196 |
+
if __name__ == '__main__':
    # getChat("测试一下", 0)
    # main()
    # NOTE(review): this guard sits above the route definitions. It still
    # works because app='main:app' makes uvicorn import the module by name
    # (executing the whole file, including the routes below), but moving
    # the guard to the end of the file would be clearer — confirm before
    # reordering.
    uvicorn.run(app='main:app', host="localhost",
                port=7860, reload=False)
    # testFunc()
|
202 |
+
|
203 |
+
|
204 |
+
class Items1(BaseModel):
    """Request schema for the root POST endpoint (see ``get_Chat``)."""
    context: str  # raw input text forwarded to getChat
    userid: int   # caller identifier, passed through to getChat
    # must: bool
|
208 |
+
|
209 |
+
|
210 |
+
import time

# Millisecond timestamp of the last reply. The only write to it inside
# get_Chat is commented out, so it effectively stays 0 at runtime.
lastReplyTime = 0
|
213 |
+
|
214 |
|
215 |
+
@app.post("/")
async def get_Chat(item1: Items1):
    """POST handler for the root path: run getChat on the request body.

    The original rate-limiting scaffolding is disabled (commented out
    upstream); the request timestamp is still taken so the observable
    behavior — including the ``time.time()`` call — is unchanged.
    """
    global lastReplyTime
    tempReplyTime = int(time.time() * 1000)  # kept for the (disabled) rate limiter
    result = getChat(
        item1.context, item1.userid)
    return {"res": result}
|