File size: 6,974 Bytes
5255efb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
"""
1. 这里用公网Qwen API来代替本地的ChatGLM模型。为了在Huggingface上演示。
1. 使用确定了列名作为SQL语句的变量名,可以有效解决模型生成的SQL语句中变量名准确的问题。

"""
##TODO: 

import requests
import os
from rich import print
import os
import sys
import time
import pandas as pd
import numpy as np
import sys
import time
from typing import Any
import requests
import csv
import os
from rich import print
import pandas
import io
from io import StringIO
import re
from langchain.llms.utils import enforce_stop_tokens
import json
from transformers import AutoModel, AutoTokenizer
import mdtex2html
import qwen_response


''' Start: Environment settings. '''
os.environ['SENTENCE_TRANSFORMERS_HOME'] = '/Users/yunshi/Downloads/chatGLM/My_LocalKB_Project/'
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
import torch
mps_device = torch.device("mps") ## 在mac机器上需要加上这句。必须要有这句,否则会报错。

### 在langchain中定义chatGLM作为LLM。
from typing import Any, List, Mapping, Optional
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from transformers import AutoTokenizer, AutoModel
# llm_filepath = str("/Users/yunshi/Downloads/chatGLM/ChatGLM3-6B/6B") ## 第三代chatGLM 6B W/ code-interpreter


# ## API模式启动ChatGLM
# ## 配置ChatGLM的类与后端api server对应。
# class ChatGLM(LLM):
#     max_token: int = 2048
#     temperature: float = 0.1
#     top_p = 0.9
#     history = []

#     def __init__(self):
#         super().__init__()

#     @property
#     def _llm_type(self) -> str:
#         return "ChatGLM"

#     def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
#         # headers中添加上content-type这个参数,指定为json格式
#         headers = {'Content-Type': 'application/json'}
#         data=json.dumps({
#             'prompt':prompt,
#             'temperature':self.temperature,
#             'history':self.history,
#             'max_length':self.max_token
#         })
#         print("ChatGLM prompt:",prompt)
#         # 调用api
#         # response = requests.post("http://0.0.0.0:8000",headers=headers,data=data) ##working。
#         response = requests.post("http://127.0.0.1:8000",headers=headers,data=data) ##working。
#         print("ChatGLM resp:", response)
        
#         if response.status_code!=200:
#             return "查询结果错误"
#         resp = response.json()
#         if stop is not None:
#             response = enforce_stop_tokens(response, stop)
#         self.history = self.history+[[None, resp['response']]] ##original
#         return resp['response'] ##original.

# llm = ChatGLM() ## 启动一个实例。orignal working。
# import asyncio
# llm =  ChatGLM() ## 启动一个实例。


''' End: Environment settings. '''

### 我会用中文或者英文双引号(即:“ ”," ")来告知你变量的名称。 长度","宽度","价格","产品ID","比率","类别","*"

### 用ChatGLM构建一个只返回SQL语句的模型。
def main(prompt):
    full_reponse = []
    sys_prompt = """
    1. 你是一个将文字转换成SQL语句的人工智能。
    2. 你需要注意:你只需要用纯文本回复代码的内容,即你不允许回复代码以外的任何信息。
    3. SQL变量默认是中文,而且只能从如下的名称列表中选择,你不可以使用这些名字以外的变量名:"长度","宽度","价格","产品ID","比率","类别","*"
    4. 你不能写IF, THEN的SQL语句,需要使用CASE。
    5. 我需要你转换的文字如下:"""   
    
    total_prompt = sys_prompt + "在数据表格table01中," + prompt
    
    print('total prompt now:',total_prompt)
    # for response, history in chatglm.model.stream_chat(chatglm.tokenizer, query=str(prompt)): ## 这里保留了所有的chat history在input_prompt中。
    # for response, history in chatglm.model.stream_chat(chatglm.tokenizer, query=str(prompt)): ## 这里保留了所有的chat history在input_prompt中。
    # for response, history in chatglm.model.stream_chat(chatglm.tokenizer, query=str(total_prompt)): ## 这里保留了所有的chat history在input_prompt中。

            # # for response, history in chatglm.model.stream_chat(chatglm.tokenizer, query=str(input_prompt[-1][0])): ## 从用langchain的自定义方式来做。
            # # for response, history in chatglm.model.stream_chat(chatglm.tokenizer, query=str(input_prompt[-1][0]), history=input_prompt, max_length=max_tokens, top_p=top_p, temperature=temperature): ## 从用langchain的自定义方式来做。
            # if response != "<br>":
            #     # print('response of model:', response)
            #     # input_prompt[-1][1] = response ## working.
            #     # input_prompt[-1][1] = response
            #     # yield input_prompt
        # full_reponse.append(response)
    
    # ## 得到一个非stream格式的答复。非API模式。    
    # response, history = chatglm.model.chat(chatglm.tokenizer, query=str(total_prompt), temperature=0.1) ## 这里保留了所有的chat history在input_prompt中。
    
    ###TODO:API模式,需要先启动API服务器。
    # llm =  ChatGLM() ##!! 重要说明:每次都需要实例化一次!!!否则会报错content error。实际上是应该在每次函数调用的时候都要实例化一次!
    # response = llm(total_prompt) ## 这里是本地的ChatGLM来作为大模型输出基座。
    
    
    response = qwen_response.call_with_messages(total_prompt)
    
    print('response of model:', response)
    
    ## 用regex来提取纯SQL语句。需要构建多个正则式pattern
    pattern_1 = r"(?:`sql\n|\n`)"
    pattern_2 = r"(?:```|``)"
    pattern_3 = r"(?s)(.*?SQL语句示例.*?:).*?\n"
    pattern_4 = r"(?:`{3}|`{2}|`)"
    # pattern_5 = r"[\u4e00-\u9FFF]" ## 匹配中文。
    # pattern_6 = r"^[\u4e00-\u9fa5]{5,}" ## 首行中包含5个中文汉字的。
    pattern_7 = r"^.{0,2}([\u4e00-\u9fa5]{5,}).*" ## 首行中包含5个中文汉字的。 
    pattern_8 = r'^"|"$' ## 去除一句话开始或者末尾的英文双引号 
    pattern_list = [pattern_1, pattern_2, pattern_3, pattern_4, pattern_7, pattern_8]
    
    ## 遍历所有的pattern,逐个去除。
    full_reponse = response
    for p in pattern_list:
        full_reponse = re.sub(p, "", full_reponse)
    # final_response = re.sub(pattern_1, "", response) ## 逐步匹配。
    # final_response = re.sub(pattern_1, "", response) ## 逐步匹配。
    # final_response = re.sub(pattern_2, "", final_response)  ## 逐步匹配。

    return full_reponse

# prompt = "你给我一段复杂的SQL语句示例。"
# prompt = "你给我一段SQL语句,用来完成如下工作:查询年龄大于30岁,男性,收入超过2万元的员工。"
# res = main(prompt=prompt)
# print(res)