Tuchuanhuhuhu committed · commit 2d5d187
1 parent: 0127941
支持本地embedding (Support local embedding)
Files changed:
- modules/base_model.py  +22 -9
- modules/config.py      +2 -0
- modules/llama_func.py  +37 -14
- modules/models.py      +3 -0
- requirements.txt       +1 -0
modules/base_model.py
CHANGED
@@ -132,8 +132,8 @@ class BaseLLMModel:
             status_text = self.token_message()
             yield get_return_value()
             if self.interrupted:
-
-
+                self.recover()
+                break
         self.history.append(construct_assistant(partial_text))

     def next_chatbot_at_once(self, inputs, chatbot, fake_input=None, display_append=""):
@@ -170,7 +170,14 @@ class BaseLLMModel:
     ):  # repetition_penalty, top_k
         from llama_index.indices.vector_store.base_query import GPTVectorStoreIndexQuery
         from llama_index.indices.query.schema import QueryBundle
-        from langchain.
+        from langchain.embeddings.huggingface import HuggingFaceEmbeddings
+        from langchain.chat_models import ChatOpenAI
+        from llama_index import (
+            GPTSimpleVectorIndex,
+            ServiceContext,
+            LangchainEmbedding,
+            OpenAIEmbedding,
+        )

         logging.info(
             "输入为:" + colorama.Fore.BLUE + f"{inputs}" + colorama.Style.RESET_ALL
@@ -182,20 +189,22 @@ class BaseLLMModel:
         old_inputs = None
         display_reference = []
         limited_context = False
-        if files
+        if files:
             limited_context = True
             old_inputs = inputs
             msg = "加载索引中……(这可能需要几分钟)"
             logging.info(msg)
             yield chatbot + [(inputs, "")], msg
             index = construct_index(self.api_key, file_src=files)
+            assert index is not None, "索引构建失败"
             msg = "索引构建完成,获取回答中……"
+            if local_embedding:
+                embed_model = LangchainEmbedding(HuggingFaceEmbeddings())
+            else:
+                embed_model = OpenAIEmbedding()
             logging.info(msg)
             yield chatbot + [(inputs, "")], msg
             with retrieve_proxy():
-                llm_predictor = LLMPredictor(
-                    llm=OpenAIChat(temperature=0, model_name=self.model_name)
-                )
                 prompt_helper = PromptHelper(
                     max_input_size=4096,
                     num_output=5,
@@ -205,7 +214,7 @@ class BaseLLMModel:
                 from llama_index import ServiceContext

                 service_context = ServiceContext.from_defaults(
-
+                    prompt_helper=prompt_helper, embed_model=embed_model
                 )
                 query_object = GPTVectorStoreIndexQuery(
                     index.index_struct,
@@ -249,7 +258,11 @@ class BaseLLMModel:
         else:
            display_reference = ""

-        if
+        if (
+            self.api_key is not None
+            and len(self.api_key) == 0
+            and not shared.state.multi_api_key
+        ):
            status_text = STANDARD_ERROR_MSG + NO_APIKEY_MSG
            logging.info(status_text)
            chatbot.append((inputs, ""))
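
Taken together, these hunks let document-grounded chat run without OpenAI's embedding endpoint: when local_embedding is set, the query is embedded with a local HuggingFace model wrapped for llama-index; otherwise OpenAIEmbedding() is used as before. A minimal sketch of that switch, assuming the llama-index 0.5-era and langchain 0.0.x APIs this file imports (pick_embed_model is a name invented here for illustration):

    # Sketch of the embedding switch, not the repository's exact code.
    from langchain.embeddings.huggingface import HuggingFaceEmbeddings
    from llama_index import LangchainEmbedding, OpenAIEmbedding, ServiceContext

    def pick_embed_model(local_embedding: bool):
        if local_embedding:
            # Local sentence-transformers model; downloaded on first use,
            # then embeds on this machine with no per-query API call.
            return LangchainEmbedding(HuggingFaceEmbeddings())
        # Hosted OpenAI embeddings, the behavior before this commit.
        return OpenAIEmbedding()

    service_context = ServiceContext.from_defaults(
        embed_model=pick_embed_model(local_embedding=True)
    )
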
modules/config.py
CHANGED
@@ -117,6 +117,8 @@ https_proxy = os.environ.get("HTTPS_PROXY", https_proxy)
 os.environ["HTTP_PROXY"] = ""
 os.environ["HTTPS_PROXY"] = ""

+local_embedding = config.get("local_embedding", False)  # 是否使用本地embedding
+
 @contextmanager
 def retrieve_proxy(proxy=None):
     """
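
The flag is read once when modules/config.py is imported, with a default of False, so existing deployments keep using OpenAI embeddings. Opting in is one line in the config file (assuming the project's usual config.json format, which this config.get call parses):

    "local_embedding": true
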
modules/llama_func.py
CHANGED
@@ -15,6 +15,8 @@ from tqdm import tqdm

 from modules.presets import *
 from modules.utils import *
+from modules.config import local_embedding
+

 def get_index_name(file_src):
     file_paths = [x.name for x in file_src]
@@ -28,6 +30,7 @@ def get_index_name(file_src):

     return md5_hash.hexdigest()

+
 def block_split(text):
     blocks = []
     while len(text) > 0:
@@ -35,6 +38,7 @@ def block_split(text):
         text = text[1000:]
     return blocks

+
 def get_documents(file_src):
     documents = []
     logging.debug("Loading documents...")
@@ -50,11 +54,12 @@ def get_documents(file_src):
            try:
                from modules.pdf_func import parse_pdf
                from modules.config import advance_docs
+
                two_column = advance_docs["pdf"].get("two_column", False)
                pdftext = parse_pdf(filepath, two_column).text
            except:
                pdftext = ""
-                with open(filepath,
+                with open(filepath, "rb") as pdfFileObj:
                    pdfReader = PyPDF2.PdfReader(pdfFileObj)
                    for page in tqdm(pdfReader.pages):
                        pdftext += page.extract_text()
@@ -91,19 +96,21 @@ def get_documents(file_src):


 def construct_index(
-
-
-
-
-
-
-
-
+    api_key,
+    file_src,
+    max_input_size=4096,
+    num_outputs=5,
+    max_chunk_overlap=20,
+    chunk_size_limit=600,
+    embedding_limit=None,
+    separator=" ",
 ):
     from langchain.chat_models import ChatOpenAI
-    from
+    from langchain.embeddings.huggingface import HuggingFaceEmbeddings
+    from llama_index import GPTSimpleVectorIndex, ServiceContext, LangchainEmbedding, OpenAIEmbedding

-
+    if api_key:
+        os.environ["OPENAI_API_KEY"] = api_key
     chunk_size_limit = None if chunk_size_limit == 0 else chunk_size_limit
     embedding_limit = None if embedding_limit == 0 else embedding_limit
     separator = " " if separator == "" else separator
@@ -111,7 +118,14 @@ def construct_index(
     llm_predictor = LLMPredictor(
         llm=ChatOpenAI(model_name="gpt-3.5-turbo-0301", openai_api_key=api_key)
     )
-    prompt_helper = PromptHelper(
+    prompt_helper = PromptHelper(
+        max_input_size=max_input_size,
+        num_output=num_outputs,
+        max_chunk_overlap=max_chunk_overlap,
+        embedding_limit=embedding_limit,
+        chunk_size_limit=600,
+        separator=separator,
+    )
     index_name = get_index_name(file_src)
     if os.path.exists(f"./index/{index_name}.json"):
         logging.info("找到了缓存的索引文件,加载中……")
@@ -119,11 +133,20 @@ def construct_index(
     else:
        try:
            documents = get_documents(file_src)
+           if local_embedding:
+               embed_model = LangchainEmbedding(HuggingFaceEmbeddings())
+           else:
+               embed_model = OpenAIEmbedding()
            logging.info("构建索引中……")
            with retrieve_proxy():
-               service_context = ServiceContext.from_defaults(
+               service_context = ServiceContext.from_defaults(
+                   llm_predictor=llm_predictor,
+                   prompt_helper=prompt_helper,
+                   chunk_size_limit=chunk_size_limit,
+                   embed_model=LangchainEmbedding(HuggingFaceEmbeddings()),
+               )
                index = GPTSimpleVectorIndex.from_documents(
-                   documents,
+                   documents, service_context=service_context
                )
                logging.debug("索引构建完成!")
                os.makedirs("./index", exist_ok=True)
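
The rewritten construct_index now takes an explicit signature, exports the key via OPENAI_API_KEY, builds a PromptHelper from its parameters, and threads a ServiceContext into GPTSimpleVectorIndex.from_documents. Note that this ServiceContext is built with embed_model=LangchainEmbedding(HuggingFaceEmbeddings()) directly, so as committed, index construction always embeds locally; the embed_model chosen by the if local_embedding branch just above it is not the one passed in. A usage sketch under those assumptions (UploadedFile is a stand-in invented here for the Gradio file objects the app passes, which only need a .name attribute):

    from collections import namedtuple

    # Stand-in for the Gradio file objects; get_index_name and get_documents
    # only read their .name attribute.
    UploadedFile = namedtuple("UploadedFile", ["name"])

    index = construct_index(
        api_key="sk-...",  # placeholder; still wired into the ChatOpenAI predictor
        file_src=[UploadedFile("docs/manual.pdf")],
    )
    # The result is cached as ./index/<md5>.json (see get_index_name), so a
    # second call with the same files loads the cache instead of re-embedding.
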
modules/models.py
CHANGED
@@ -30,6 +30,7 @@ from .llama_func import *
 from .utils import *
 from . import shared
 from .config import retrieve_proxy
+from modules import config
 from .base_model import BaseLLMModel, ModelType


@@ -379,6 +380,8 @@ class ModelManager:
         msg = f"模型设置为了: {model_name}"
         logging.info(msg)
         model_type = ModelType.get_type(model_name)
+        if model_type != ModelType.OpenAI:
+            config.local_embedding = True
         if model_type == ModelType.OpenAI:
             model = OpenAIClient(
                 model_name=model_name,
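
For any non-OpenAI model the manager force-enables local embeddings, since there may be no OpenAI key to embed with. It does so by assigning through the module object (config.local_embedding = True), which later reads of config.local_embedding will observe. One caveat worth knowing when tracing this flag: modules/llama_func.py imports it by value ("from modules.config import local_embedding"), and a name bound that way does not follow later reassignment of the module attribute. A self-contained demonstration of that from-import behavior:

    import types

    # Throwaway module-like object standing in for modules.config.
    config = types.SimpleNamespace(local_embedding=False)

    # "from modules.config import local_embedding" copies the current value:
    local_embedding = config.local_embedding

    config.local_embedding = True   # what ModelManager does for non-OpenAI models

    print(config.local_embedding)   # True: attribute reads see the change
    print(local_embedding)          # False: the value imported earlier does not
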
requirements.txt
CHANGED
@@ -19,3 +19,4 @@ mpi4py
 icetk
 git+https://github.com/OptimalScale/LMFlow.git
 cpm-kernels
+sentence_transformers
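
sentence_transformers is the package that langchain's HuggingFaceEmbeddings wraps, so it is only exercised when the local path is used. A quick smoke test of that path, assuming langchain 0.0.x, where HuggingFaceEmbeddings defaults to the sentence-transformers/all-mpnet-base-v2 model (fetched on first use):

    from langchain.embeddings.huggingface import HuggingFaceEmbeddings

    emb = HuggingFaceEmbeddings()           # defaults to all-mpnet-base-v2
    vec = emb.embed_query("支持本地embedding")
    print(len(vec))                         # 768 dimensions for the default model
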