Tuchuanhuhuhu committed on
Commit 2d5d187 · 1 Parent(s): 0127941

Support local embeddings (支持本地embedding)

modules/base_model.py CHANGED
@@ -132,8 +132,8 @@ class BaseLLMModel:
             status_text = self.token_message()
             yield get_return_value()
             if self.interrupted:
-                self.recover()
-                break
+                self.recover()
+                break
         self.history.append(construct_assistant(partial_text))
 
     def next_chatbot_at_once(self, inputs, chatbot, fake_input=None, display_append=""):
@@ -170,7 +170,14 @@ class BaseLLMModel:
    ): # repetition_penalty, top_k
        from llama_index.indices.vector_store.base_query import GPTVectorStoreIndexQuery
        from llama_index.indices.query.schema import QueryBundle
-       from langchain.llms import OpenAIChat
+       from langchain.embeddings.huggingface import HuggingFaceEmbeddings
+       from langchain.chat_models import ChatOpenAI
+       from llama_index import (
+           GPTSimpleVectorIndex,
+           ServiceContext,
+           LangchainEmbedding,
+           OpenAIEmbedding,
+       )
 
        logging.info(
            "输入为:" + colorama.Fore.BLUE + f"{inputs}" + colorama.Style.RESET_ALL
@@ -182,20 +189,22 @@ class BaseLLMModel:
        old_inputs = None
        display_reference = []
        limited_context = False
-       if files and self.api_key:
+       if files:
            limited_context = True
            old_inputs = inputs
            msg = "加载索引中……(这可能需要几分钟)"
            logging.info(msg)
            yield chatbot + [(inputs, "")], msg
            index = construct_index(self.api_key, file_src=files)
+           assert index is not None, "索引构建失败"
            msg = "索引构建完成,获取回答中……"
+           if local_embedding:
+               embed_model = LangchainEmbedding(HuggingFaceEmbeddings())
+           else:
+               embed_model = OpenAIEmbedding()
            logging.info(msg)
            yield chatbot + [(inputs, "")], msg
            with retrieve_proxy():
-               llm_predictor = LLMPredictor(
-                   llm=OpenAIChat(temperature=0, model_name=self.model_name)
-               )
                prompt_helper = PromptHelper(
                    max_input_size=4096,
                    num_output=5,
@@ -205,7 +214,7 @@ class BaseLLMModel:
                from llama_index import ServiceContext
 
                service_context = ServiceContext.from_defaults(
-                   llm_predictor=llm_predictor, prompt_helper=prompt_helper
+                   prompt_helper=prompt_helper, embed_model=embed_model
                )
                query_object = GPTVectorStoreIndexQuery(
                    index.index_struct,
@@ -249,7 +258,11 @@ class BaseLLMModel:
        else:
            display_reference = ""
 
-       if self.api_key is not None and len(self.api_key) == 0 and not shared.state.multi_api_key:
+       if (
+           self.api_key is not None
+           and len(self.api_key) == 0
+           and not shared.state.multi_api_key
+       ):
            status_text = STANDARD_ERROR_MSG + NO_APIKEY_MSG
            logging.info(status_text)
            chatbot.append((inputs, ""))
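
The query path above now selects its embedding model from the local_embedding flag instead of always constructing an OpenAI LLMPredictor, so document Q&A no longer hard-requires an API key. A minimal sketch of querying a saved index the same way, using the llama_index 0.5.x-era API imported in this diff (the index path and question string are hypothetical):

    from langchain.embeddings.huggingface import HuggingFaceEmbeddings
    from llama_index import GPTSimpleVectorIndex, LangchainEmbedding, ServiceContext

    # Local embedding model: runs sentence-transformers on this machine,
    # so no OpenAI key is needed for retrieval.
    embed_model = LangchainEmbedding(HuggingFaceEmbeddings())
    service_context = ServiceContext.from_defaults(embed_model=embed_model)

    # "index.json" is a hypothetical path to an index saved by construct_index().
    index = GPTSimpleVectorIndex.load_from_disk("index.json", service_context=service_context)
    print(index.query("文档讲了什么?"))

Note that a query must embed with the same model the index was built with; mixing OpenAIEmbedding at build time with a local model at query time would compare vectors from two different spaces.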
modules/config.py CHANGED
@@ -117,6 +117,8 @@ https_proxy = os.environ.get("HTTPS_PROXY", https_proxy)
 os.environ["HTTP_PROXY"] = ""
 os.environ["HTTPS_PROXY"] = ""
 
+local_embedding = config.get("local_embedding", False)  # 是否使用本地embedding (whether to use local embeddings)
+
 @contextmanager
 def retrieve_proxy(proxy=None):
     """
modules/llama_func.py CHANGED
@@ -15,6 +15,8 @@ from tqdm import tqdm
 
 from modules.presets import *
 from modules.utils import *
+from modules.config import local_embedding
+
 
 def get_index_name(file_src):
     file_paths = [x.name for x in file_src]
@@ -28,6 +30,7 @@ def get_index_name(file_src):
 
     return md5_hash.hexdigest()
 
+
 def block_split(text):
     blocks = []
     while len(text) > 0:
@@ -35,6 +38,7 @@
         text = text[1000:]
     return blocks
 
+
 def get_documents(file_src):
     documents = []
     logging.debug("Loading documents...")
@@ -50,11 +54,12 @@ def get_documents(file_src):
            try:
                from modules.pdf_func import parse_pdf
                from modules.config import advance_docs
+
                two_column = advance_docs["pdf"].get("two_column", False)
                pdftext = parse_pdf(filepath, two_column).text
            except:
                pdftext = ""
-           with open(filepath, 'rb') as pdfFileObj:
+           with open(filepath, "rb") as pdfFileObj:
                pdfReader = PyPDF2.PdfReader(pdfFileObj)
                for page in tqdm(pdfReader.pages):
                    pdftext += page.extract_text()
@@ -91,19 +96,21 @@ def get_documents(file_src):
 
 
 def construct_index(
-    api_key,
-    file_src,
-    max_input_size=4096,
-    num_outputs=5,
-    max_chunk_overlap=20,
-    chunk_size_limit=600,
-    embedding_limit=None,
-    separator=" "
+    api_key,
+    file_src,
+    max_input_size=4096,
+    num_outputs=5,
+    max_chunk_overlap=20,
+    chunk_size_limit=600,
+    embedding_limit=None,
+    separator=" ",
 ):
     from langchain.chat_models import ChatOpenAI
-    from llama_index import GPTSimpleVectorIndex, ServiceContext
+    from langchain.embeddings.huggingface import HuggingFaceEmbeddings
+    from llama_index import GPTSimpleVectorIndex, ServiceContext, LangchainEmbedding, OpenAIEmbedding
 
-    os.environ["OPENAI_API_KEY"] = api_key
+    if api_key:
+        os.environ["OPENAI_API_KEY"] = api_key
     chunk_size_limit = None if chunk_size_limit == 0 else chunk_size_limit
     embedding_limit = None if embedding_limit == 0 else embedding_limit
     separator = " " if separator == "" else separator
@@ -111,7 +118,14 @@ def construct_index(
     llm_predictor = LLMPredictor(
         llm=ChatOpenAI(model_name="gpt-3.5-turbo-0301", openai_api_key=api_key)
     )
-    prompt_helper = PromptHelper(max_input_size = max_input_size, num_output = num_outputs, max_chunk_overlap = max_chunk_overlap, embedding_limit=embedding_limit, chunk_size_limit=600, separator=separator)
+    prompt_helper = PromptHelper(
+        max_input_size=max_input_size,
+        num_output=num_outputs,
+        max_chunk_overlap=max_chunk_overlap,
+        embedding_limit=embedding_limit,
+        chunk_size_limit=600,
+        separator=separator,
+    )
     index_name = get_index_name(file_src)
     if os.path.exists(f"./index/{index_name}.json"):
         logging.info("找到了缓存的索引文件,加载中……")
@@ -119,11 +133,20 @@ def construct_index(
     else:
         try:
             documents = get_documents(file_src)
+            if local_embedding:
+                embed_model = LangchainEmbedding(HuggingFaceEmbeddings())
+            else:
+                embed_model = OpenAIEmbedding()
            logging.info("构建索引中……")
            with retrieve_proxy():
-               service_context = ServiceContext.from_defaults(llm_predictor=llm_predictor, prompt_helper=prompt_helper, chunk_size_limit=chunk_size_limit)
+               service_context = ServiceContext.from_defaults(
+                   llm_predictor=llm_predictor,
+                   prompt_helper=prompt_helper,
+                   chunk_size_limit=chunk_size_limit,
+                   embed_model=embed_model,
+               )
            index = GPTSimpleVectorIndex.from_documents(
-               documents, service_context=service_context
+               documents, service_context=service_context
            )
         logging.debug("索引构建完成!")
         os.makedirs("./index", exist_ok=True)
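
construct_index now chooses between a local sentence-transformers model and OpenAI's embedding endpoint when building the vector index. A self-contained sketch of the same selection outside the app, using the llama_index 0.5.x-era calls imported above (the docs/ folder is hypothetical, and HuggingFaceEmbeddings downloads a default sentence-transformers model on first use):

    from langchain.embeddings.huggingface import HuggingFaceEmbeddings
    from llama_index import (
        GPTSimpleVectorIndex,
        LangchainEmbedding,
        OpenAIEmbedding,
        ServiceContext,
        SimpleDirectoryReader,
    )

    local_embedding = True  # mirrors the new config flag

    if local_embedding:
        # Embeds on this machine; needs sentence_transformers but no API key.
        embed_model = LangchainEmbedding(HuggingFaceEmbeddings())
    else:
        # Calls OpenAI's embedding API; needs OPENAI_API_KEY to be set.
        embed_model = OpenAIEmbedding()

    documents = SimpleDirectoryReader("docs").load_data()  # hypothetical folder
    service_context = ServiceContext.from_defaults(embed_model=embed_model)
    index = GPTSimpleVectorIndex.from_documents(documents, service_context=service_context)
    index.save_to_disk("index.json")

One caveat visible in the code above: the cache key from get_index_name hashes only the uploaded files, so an index cached under ./index/ is reused even if the local_embedding flag has changed since it was built, serving vectors from the old embedding model until the cache is cleared.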
modules/models.py CHANGED
@@ -30,6 +30,7 @@ from .llama_func import *
 from .utils import *
 from . import shared
 from .config import retrieve_proxy
+from modules import config
 from .base_model import BaseLLMModel, ModelType
 
 
@@ -379,6 +380,8 @@ class ModelManager:
        msg = f"模型设置为了: {model_name}"
        logging.info(msg)
        model_type = ModelType.get_type(model_name)
+       if model_type != ModelType.OpenAI:
+           config.local_embedding = True
        if model_type == ModelType.OpenAI:
            model = OpenAIClient(
                model_name=model_name,
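
Selecting any non-OpenAI model now flips config.local_embedding at runtime, since there is no OpenAI key to embed with in that case. One Python subtlety: modules that imported the flag by value, as llama_func.py does with `from modules.config import local_embedding`, keep their original binding and do not observe this later mutation; only reads through the module object (config.local_embedding) see it. A standalone demonstration with a throwaway module:

    import types

    # Hypothetical stand-in for modules/config.py with the flag defaulting to False.
    flagmod = types.ModuleType("flagmod")
    flagmod.local_embedding = False

    # `from flagmod import local_embedding` copies the current value to a new name:
    local_embedding = flagmod.local_embedding

    # models.py-style mutation rebinds the module attribute only:
    flagmod.local_embedding = True

    print(local_embedding)          # False -- the copied binding is unaffected
    print(flagmod.local_embedding)  # True  -- attribute access sees the change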
requirements.txt CHANGED
@@ -19,3 +19,4 @@ mpi4py
 icetk
 git+https://github.com/OptimalScale/LMFlow.git
 cpm-kernels
+sentence_transformers
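
The new sentence_transformers dependency provides the local models that HuggingFaceEmbeddings wraps. A quick smoke test of the package itself (the model name is sentence-transformers' small example model, chosen here for illustration; it is not necessarily the default HuggingFaceEmbeddings loads):

    from sentence_transformers import SentenceTransformer

    # Downloads a small (~80 MB) model on first run, then embeds fully offline.
    model = SentenceTransformer("all-MiniLM-L6-v2")
    vec = model.encode("支持本地embedding")
    print(vec.shape)  # (384,) -- this model emits 384-dimensional vectors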