""" | |
1. 这里用公网Qwen API来代替本地的ChatGLM模型。为了在Huggingface上演示。 | |
1. 使用确定了列名作为SQL语句的变量名,可以有效解决模型生成的SQL语句中变量名准确的问题。 | |
""" | |
##TODO: | |
import requests | |
import os | |
from rich import print | |
import os | |
import sys | |
import time | |
import pandas as pd | |
import numpy as np | |
import sys | |
import time | |
from typing import Any | |
import requests | |
import csv | |
import os | |
from rich import print | |
import pandas | |
import io | |
from io import StringIO | |
import re | |
from langchain.llms.utils import enforce_stop_tokens | |
import json | |
from transformers import AutoModel, AutoTokenizer | |
import mdtex2html | |
import qwen_response | |
''' Start: Environment settings. ''' | |
os.environ['SENTENCE_TRANSFORMERS_HOME'] = '/Users/yunshi/Downloads/chatGLM/My_LocalKB_Project/' | |
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" | |
import torch | |
mps_device = torch.device("mps") ## 在mac机器上需要加上这句。必须要有这句,否则会报错。 | |
### 在langchain中定义chatGLM作为LLM。 | |
from typing import Any, List, Mapping, Optional | |
from langchain.callbacks.manager import CallbackManagerForLLMRun | |
from langchain.llms.base import LLM | |
from transformers import AutoTokenizer, AutoModel | |
# llm_filepath = str("/Users/yunshi/Downloads/chatGLM/ChatGLM3-6B/6B") ## 第三代chatGLM 6B W/ code-interpreter | |
# ## API模式启动ChatGLM | |
# ## 配置ChatGLM的类与后端api server对应。 | |
# class ChatGLM(LLM): | |
# max_token: int = 2048 | |
# temperature: float = 0.1 | |
# top_p = 0.9 | |
# history = [] | |
# def __init__(self): | |
# super().__init__() | |
# @property | |
# def _llm_type(self) -> str: | |
# return "ChatGLM" | |
# def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: | |
# # headers中添加上content-type这个参数,指定为json格式 | |
# headers = {'Content-Type': 'application/json'} | |
# data=json.dumps({ | |
# 'prompt':prompt, | |
# 'temperature':self.temperature, | |
# 'history':self.history, | |
# 'max_length':self.max_token | |
# }) | |
# print("ChatGLM prompt:",prompt) | |
# # 调用api | |
# # response = requests.post("http://0.0.0.0:8000",headers=headers,data=data) ##working。 | |
# response = requests.post("http://127.0.0.1:8000",headers=headers,data=data) ##working。 | |
# print("ChatGLM resp:", response) | |
# if response.status_code!=200: | |
# return "查询结果错误" | |
# resp = response.json() | |
# if stop is not None: | |
# response = enforce_stop_tokens(response, stop) | |
# self.history = self.history+[[None, resp['response']]] ##original | |
# return resp['response'] ##original. | |
# llm = ChatGLM() ## 启动一个实例。orignal working。 | |
# import asyncio | |
# llm = ChatGLM() ## 启动一个实例。 | |
''' End: Environment settings. ''' | |
### 我会用中文或者英文双引号(即:“ ”," ")来告知你变量的名称。 长度","宽度","价格","产品ID","比率","类别","*" | |
### 用ChatGLM构建一个只返回SQL语句的模型。 | |
def main(prompt): | |
full_reponse = [] | |
sys_prompt = """ | |
1. 你是一个将文字转换成SQL语句的人工智能。 | |
2. 你需要注意:你只需要用纯文本回复代码的内容,即你不允许回复代码以外的任何信息。 | |
3. SQL变量默认是中文,而且只能从如下的名称列表中选择,你不可以使用这些名字以外的变量名:"长度","宽度","价格","产品ID","比率","类别","*" | |
4. 你不能写IF, THEN的SQL语句,需要使用CASE。 | |
5. 我需要你转换的文字如下:""" | |
total_prompt = sys_prompt + "在数据表格table01中," + prompt | |
print('total prompt now:',total_prompt) | |
# for response, history in chatglm.model.stream_chat(chatglm.tokenizer, query=str(prompt)): ## 这里保留了所有的chat history在input_prompt中。 | |
# for response, history in chatglm.model.stream_chat(chatglm.tokenizer, query=str(prompt)): ## 这里保留了所有的chat history在input_prompt中。 | |
# for response, history in chatglm.model.stream_chat(chatglm.tokenizer, query=str(total_prompt)): ## 这里保留了所有的chat history在input_prompt中。 | |
# # for response, history in chatglm.model.stream_chat(chatglm.tokenizer, query=str(input_prompt[-1][0])): ## 从用langchain的自定义方式来做。 | |
# # for response, history in chatglm.model.stream_chat(chatglm.tokenizer, query=str(input_prompt[-1][0]), history=input_prompt, max_length=max_tokens, top_p=top_p, temperature=temperature): ## 从用langchain的自定义方式来做。 | |
# if response != "<br>": | |
# # print('response of model:', response) | |
# # input_prompt[-1][1] = response ## working. | |
# # input_prompt[-1][1] = response | |
# # yield input_prompt | |
# full_reponse.append(response) | |
# ## 得到一个非stream格式的答复。非API模式。 | |
# response, history = chatglm.model.chat(chatglm.tokenizer, query=str(total_prompt), temperature=0.1) ## 这里保留了所有的chat history在input_prompt中。 | |
###TODO:API模式,需要先启动API服务器。 | |
# llm = ChatGLM() ##!! 重要说明:每次都需要实例化一次!!!否则会报错content error。实际上是应该在每次函数调用的时候都要实例化一次! | |
# response = llm(total_prompt) ## 这里是本地的ChatGLM来作为大模型输出基座。 | |
response = qwen_response.call_with_messages(total_prompt) | |
print('response of model:', response) | |
## 用regex来提取纯SQL语句。需要构建多个正则式pattern | |
pattern_1 = r"(?:`sql\n|\n`)" | |
pattern_2 = r"(?:```|``)" | |
pattern_3 = r"(?s)(.*?SQL语句示例.*?:).*?\n" | |
pattern_4 = r"(?:`{3}|`{2}|`)" | |
# pattern_5 = r"[\u4e00-\u9FFF]" ## 匹配中文。 | |
# pattern_6 = r"^[\u4e00-\u9fa5]{5,}" ## 首行中包含5个中文汉字的。 | |
pattern_7 = r"^.{0,2}([\u4e00-\u9fa5]{5,}).*" ## 首行中包含5个中文汉字的。 | |
pattern_8 = r'^"|"$' ## 去除一句话开始或者末尾的英文双引号 | |
pattern_list = [pattern_1, pattern_2, pattern_3, pattern_4, pattern_7, pattern_8] | |
## 遍历所有的pattern,逐个去除。 | |
full_reponse = response | |
for p in pattern_list: | |
full_reponse = re.sub(p, "", full_reponse) | |
# final_response = re.sub(pattern_1, "", response) ## 逐步匹配。 | |
# final_response = re.sub(pattern_1, "", response) ## 逐步匹配。 | |
# final_response = re.sub(pattern_2, "", final_response) ## 逐步匹配。 | |
return full_reponse | |
# prompt = "你给我一段复杂的SQL语句示例。" | |
# prompt = "你给我一段SQL语句,用来完成如下工作:查询年龄大于30岁,男性,收入超过2万元的员工。" | |
# res = main(prompt=prompt) | |
# print(res) | |