import transformers
import torch
import os
import json
import random
import threading
import time
import numpy as np
import argparse
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
from tqdm import tqdm
from torch.nn import DataParallel
import logging
from transformers import GPT2TokenizerFast, GPT2LMHeadModel, GPT2Config
from transformers import BertTokenizerFast
# from transformers import BertTokenizer
from os.path import join, exists
from itertools import zip_longest, chain
# from chatbot.model import DialogueGPT2Model
from dataset import MyDataset
from torch.utils.data import Dataset, DataLoader
from torch.nn import CrossEntropyLoss
from sklearn.model_selection import train_test_split
import torch.nn.functional as F
from fastapi import FastAPI
import uvicorn

PAD = '[PAD]'
pad_id = 0

app = FastAPI()


def set_args():
    """
    Parse the command-line arguments for the chatbot.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--device', default='0', type=str, required=False, help='生成设备')
    parser.add_argument('--temperature', default=1, type=float, required=False, help='生成的temperature')
    parser.add_argument('--topk', default=8, type=int, required=False, help='最高k选1')
    parser.add_argument('--topp', default=0, type=float, required=False, help='最高积累概率')
    # parser.add_argument('--model_config', default='config/model_config_dialogue_small.json', type=str, required=False,
    #                     help='模型参数')
    parser.add_argument('--log_path', default='data/interact.log', type=str, required=False, help='interact日志存放位置')
    parser.add_argument('--vocab_path', default='vocab/vocab.txt', type=str, required=False, help='选择词库')
    parser.add_argument('--model_path', default='model/epoch40', type=str, required=False, help='对话模型路径')
    parser.add_argument('--save_samples_path', default="sample/", type=str, required=False, help="保存聊天记录的文件路径")
    parser.add_argument('--repetition_penalty', default=1.0, type=float, required=False,
                        help="重复惩罚参数,若生成的对话重复性较高,可适当提高该参数")
    # parser.add_argument('--seed', type=int, default=None,
    #                     help='设置种子用于生成随机数,以使得训练的结果是确定的')
    parser.add_argument('--max_len', type=int, default=25, help='每个utterance的最大长度,超过指定长度则进行截断')
    parser.add_argument('--max_history_len', type=int, default=3, help="dialogue history的最大长度")
    parser.add_argument('--no_cuda', action='store_true', help='不使用GPU进行预测')
    return parser.parse_args()


def create_logger(args):
    """
    Return a logger that writes to both the log file and the console.

    Handlers are attached only once, so calling this repeatedly does not
    duplicate log lines or leak file handles.
    """
    logger = logging.getLogger(__name__)
    if logger.handlers:
        # Already configured by an earlier call.
        return logger
    logger.setLevel(logging.INFO)

    formatter = logging.Formatter(
        '%(asctime)s - %(levelname)s - %(message)s')

    # Handler that writes to the log file.
    file_handler = logging.FileHandler(
        filename=args.log_path)
    file_handler.setFormatter(formatter)
    file_handler.setLevel(logging.INFO)
    logger.addHandler(file_handler)

    # Handler that writes to the console.
    console = logging.StreamHandler()
    console.setLevel(logging.DEBUG)
    console.setFormatter(formatter)
    logger.addHandler(console)

    return logger


def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
        Args:
            logits: logits distribution shape (vocab size)
            top_k > 0: keep only top k tokens with highest probability (top-k filtering).
            top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
                Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
        From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317

        Note: ``logits`` is modified in place and also returned.
    """
    assert logits.dim() == 1  # batch size 1 for now - could be updated for more but the code would be less clear
    top_k = min(top_k, logits.size(-1))  # Safety check
    if top_k > 0:
        # Remove all tokens with a probability less than the last token of the top-k.
        # torch.topk returns (values, indices) of the top_k largest entries of the
        # last dimension; `[..., -1, None]` takes the k-th largest value.
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        # Logits outside the top-k are set to -inf.
        logits[indices_to_remove] = filter_value

    if top_p > 0.0:
        # Sort logits in descending order.
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = False

        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[indices_to_remove] = filter_value
    return logits


def _apply_repetition_penalty(logits, token_ids, penalty):
    """
    Make every token in ``token_ids`` less likely, in place.

    Positive logits are divided by ``penalty`` and negative logits are
    multiplied by it. A plain division would push a negative logit toward
    zero and so *increase* its probability (CTRL-style penalty).
    """
    for token_id in set(token_ids):
        if logits[token_id] < 0:
            logits[token_id] *= penalty
        else:
            logits[token_id] /= penalty
    return logits


def _generate_response(model, tokenizer, history, args, device):
    """
    Sample one reply from ``model`` given the dialogue ``history``.

    Args:
        model: GPT-2 LM head model in eval mode.
        tokenizer: tokenizer providing [CLS]/[SEP]/[UNK] ids.
        history: list of utterances, each a list of token ids.
        args: parsed arguments (max_len, max_history_len, temperature,
            topk, topp, repetition_penalty).
        device: torch device string for the input tensor.

    Returns:
        List of generated token ids, excluding the terminating [SEP].
    """
    # Each input starts with [CLS]; each utterance is terminated by [SEP].
    context = []
    for utterance in history[-args.max_history_len:]:
        context.extend(utterance)
        context.append(tokenizer.sep_token_id)
    # Keep prompt + generated tokens within the model's position limit;
    # otherwise an overly long (untrusted) query crashes the forward pass.
    budget = max(1, model.config.n_positions - args.max_len - 1)
    input_ids = [tokenizer.cls_token_id] + context[-budget:]
    input_ids = torch.tensor(input_ids, dtype=torch.long, device=device).unsqueeze(0)

    unk_id = tokenizer.convert_tokens_to_ids('[UNK]')
    response = []
    # Inference only: skip building autograd graphs.
    with torch.no_grad():
        # Generate at most max_len tokens.
        for _ in range(args.max_len):
            logits = model(input_ids=input_ids).logits
            next_token_logits = logits[0, -1, :]
            # Penalize tokens that were already generated in this reply.
            _apply_repetition_penalty(next_token_logits, response, args.repetition_penalty)
            next_token_logits = next_token_logits / args.temperature
            # Never emit [UNK].
            next_token_logits[unk_id] = -float('Inf')
            filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=args.topk, top_p=args.topp)
            # Sample one token index, weighted by the filtered probabilities.
            next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
            token_id = next_token.item()
            if token_id == tokenizer.sep_token_id:  # [SEP] ends the reply
                break
            response.append(token_id)
            input_ids = torch.cat((input_ids, next_token.unsqueeze(0)), dim=1)
    return response


def _load_resources():
    """
    Parse the arguments, then build the logger, tokenizer and model.

    Returns:
        Tuple ``(args, logger, device, tokenizer, model)``.
    """
    args = set_args()
    # CUDA_VISIBLE_DEVICES only takes effect if it is set before CUDA is
    # initialised, so it must come before torch.cuda.is_available().
    os.environ["CUDA_VISIBLE_DEVICES"] = args.device
    logger = create_logger(args)
    # Use the GPU only when the user allows it and one is available.
    args.cuda = torch.cuda.is_available() and not args.no_cuda
    device = 'cuda' if args.cuda else 'cpu'
    logger.info('using device:{}'.format(device))
    tokenizer = BertTokenizerFast(vocab_file=args.vocab_path, sep_token="[SEP]", pad_token="[PAD]", cls_token="[CLS]")
    # tokenizer = BertTokenizer(vocab_file=args.voca_path)
    model = GPT2LMHeadModel.from_pretrained(args.model_path)
    model = model.to(device)
    model.eval()
    if args.save_samples_path:
        os.makedirs(args.save_samples_path, exist_ok=True)
    return args, logger, device, tokenizer, model


_resources = None
_resources_lock = threading.Lock()


def _get_resources():
    """
    Return the shared ``(args, logger, device, tokenizer, model)`` tuple.

    The tuple is built on first use. The lock ensures concurrent first
    requests (served from FastAPI's threadpool) load the model only once.
    """
    global _resources
    with _resources_lock:
        if _resources is None:
            _resources = _load_resources()
    return _resources


def _samples_file_path(args):
    """Return the path of the chat-record file inside ``save_samples_path``."""
    return os.path.join(args.save_samples_path, 'samples.txt')


def main():
    """
    Run an interactive command-line chat session.

    Press Ctrl+C or Ctrl+D at the prompt to exit. If ``save_samples_path``
    is set, the conversation is appended to ``samples.txt`` there.
    """
    args, logger, device, tokenizer, model = _get_resources()
    samples_file = None
    if args.save_samples_path:
        samples_file = open(_samples_file_path(args), 'a', encoding='utf8')
        samples_file.write("聊天记录{}:\n".format(datetime.now()))
    # Chat history; each utterance is stored as a list of token ids.
    history = []
    try:
        while True:
            try:
                text = input("user:")
            except (KeyboardInterrupt, EOFError):
                break
            if samples_file is not None:
                samples_file.write("user:{}\n".format(text))
            history.append(tokenizer.encode(text, add_special_tokens=False))
            response = _generate_response(model, tokenizer, history, args, device)
            history.append(response)
            reply = "".join(tokenizer.convert_ids_to_tokens(response))
            print("chatbot:" + reply)
            if samples_file is not None:
                samples_file.write("chatbot:{}\n".format(reply))
    finally:
        if samples_file is not None:
            samples_file.close()


@app.post('/')
@app.get('/')
def chatbot_api_get(query: str):
    """
    Generate a single chatbot reply for `query` (no history is kept between requests).

    **A Successful Request would return:**\n
    - __response:__ the nested JSON object holding the fields below\n
    - __bot:__ Bot's response to the desired query\n
    - __user:__ string user passes to the API\n
    - __time_taken:__ time the server spent producing the response\n
    **Response Codes:**\n
    - __Response__ [`200`] - Success\n
    - __Response__ [`405`] - Method Not Allowed\n
    - __Response__ [`422`] - Unprocessable Entity
    """
    # Declared with plain `def` so FastAPI runs the blocking model inference
    # in its threadpool instead of stalling the event loop.
    start = time.time()
    args, _logger, device, tokenizer, model = _get_resources()

    if args.save_samples_path:
        # Same per-request header the original wrote; the file is now closed.
        with open(_samples_file_path(args), 'a', encoding='utf8') as samples_file:
            samples_file.write("聊天记录{}:\n".format(datetime.now()))

    # Each request is an independent one-turn conversation.
    history = [tokenizer.encode(query, add_special_tokens=False)]
    response = _generate_response(model, tokenizer, history, args, device)
    text = "".join(tokenizer.convert_ids_to_tokens(response))

    return {
        'response': {
            'user': query,
            'bot': text,
            'time_taken': f'{round((time.time() - start) * 1000, 3)}ms',
        }
    }


if __name__ == '__main__':
    # main()  # interactive command-line chat instead of the HTTP server
    # Load the model up front so configuration errors surface at startup and
    # the first request does not pay the loading cost.
    _get_resources()
    uvicorn.run(app, host="0.0.0.0", port=3000, log_level="info")