Commit
·
44846b2
1
Parent(s):
8ad9e26
Upload 7 files
Browse files- modules/chat_func.py +2 -2
- modules/llama_func.py +42 -38
modules/chat_func.py
CHANGED
|
@@ -155,7 +155,7 @@ def stream_predict(
|
|
| 155 |
yield get_return_value()
|
| 156 |
error_json_str = ""
|
| 157 |
|
| 158 |
-
for chunk in response.iter_lines():
|
| 159 |
if counter == 0:
|
| 160 |
counter += 1
|
| 161 |
continue
|
|
@@ -272,7 +272,7 @@ def predict(
|
|
| 272 |
if reply_language == "跟随问题语言(不稳定)":
|
| 273 |
reply_language = "the same language as the question, such as English, 中文, 日本語, Español, Français, or Deutsch."
|
| 274 |
if files:
|
| 275 |
-
msg = "
|
| 276 |
logging.info(msg)
|
| 277 |
yield chatbot+[(inputs, "")], history, msg, all_token_counts
|
| 278 |
index = construct_index(openai_api_key, file_src=files)
|
|
|
|
| 155 |
yield get_return_value()
|
| 156 |
error_json_str = ""
|
| 157 |
|
| 158 |
+
for chunk in tqdm(response.iter_lines()):
|
| 159 |
if counter == 0:
|
| 160 |
counter += 1
|
| 161 |
continue
|
|
|
|
| 272 |
if reply_language == "跟随问题语言(不稳定)":
|
| 273 |
reply_language = "the same language as the question, such as English, 中文, 日本語, Español, Français, or Deutsch."
|
| 274 |
if files:
|
| 275 |
+
msg = "加载索引中……(这可能需要几分钟)"
|
| 276 |
logging.info(msg)
|
| 277 |
yield chatbot+[(inputs, "")], history, msg, all_token_counts
|
| 278 |
index = construct_index(openai_api_key, file_src=files)
|
modules/llama_func.py
CHANGED
|
@@ -13,54 +13,57 @@ from llama_index import (
|
|
| 13 |
from langchain.llms import OpenAI
|
| 14 |
import colorama
|
| 15 |
|
| 16 |
-
|
| 17 |
from modules.presets import *
|
| 18 |
from modules.utils import *
|
| 19 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
|
| 21 |
def get_documents(file_src):
|
| 22 |
documents = []
|
| 23 |
-
index_name = ""
|
| 24 |
logging.debug("Loading documents...")
|
| 25 |
logging.debug(f"file_src: {file_src}")
|
| 26 |
for file in file_src:
|
| 27 |
-
logging.
|
| 28 |
-
index_name += file.name
|
| 29 |
if os.path.splitext(file.name)[1] == ".pdf":
|
| 30 |
logging.debug("Loading PDF...")
|
| 31 |
CJKPDFReader = download_loader("CJKPDFReader")
|
| 32 |
loader = CJKPDFReader()
|
| 33 |
-
|
| 34 |
elif os.path.splitext(file.name)[1] == ".docx":
|
| 35 |
logging.debug("Loading DOCX...")
|
| 36 |
DocxReader = download_loader("DocxReader")
|
| 37 |
loader = DocxReader()
|
| 38 |
-
|
| 39 |
elif os.path.splitext(file.name)[1] == ".epub":
|
| 40 |
logging.debug("Loading EPUB...")
|
| 41 |
EpubReader = download_loader("EpubReader")
|
| 42 |
loader = EpubReader()
|
| 43 |
-
|
| 44 |
else:
|
| 45 |
logging.debug("Loading text file...")
|
| 46 |
with open(file.name, "r", encoding="utf-8") as f:
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
return documents
|
| 51 |
|
| 52 |
|
| 53 |
def construct_index(
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
):
|
| 65 |
os.environ["OPENAI_API_KEY"] = api_key
|
| 66 |
chunk_size_limit = None if chunk_size_limit == 0 else chunk_size_limit
|
|
@@ -78,12 +81,13 @@ def construct_index(
|
|
| 78 |
chunk_size_limit,
|
| 79 |
separator=separator,
|
| 80 |
)
|
| 81 |
-
|
| 82 |
if os.path.exists(f"./index/{index_name}.json"):
|
| 83 |
logging.info("找到了缓存的索引文件,加载中……")
|
| 84 |
return GPTSimpleVectorIndex.load_from_disk(f"./index/{index_name}.json")
|
| 85 |
else:
|
| 86 |
try:
|
|
|
|
| 87 |
logging.debug("构建索引中……")
|
| 88 |
index = GPTSimpleVectorIndex(
|
| 89 |
documents, llm_predictor=llm_predictor, prompt_helper=prompt_helper
|
|
@@ -97,12 +101,12 @@ def construct_index(
|
|
| 97 |
|
| 98 |
|
| 99 |
def chat_ai(
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
):
|
| 107 |
os.environ["OPENAI_API_KEY"] = api_key
|
| 108 |
|
|
@@ -133,15 +137,15 @@ def chat_ai(
|
|
| 133 |
|
| 134 |
|
| 135 |
def ask_ai(
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
):
|
| 146 |
os.environ["OPENAI_API_KEY"] = api_key
|
| 147 |
|
|
@@ -174,7 +178,7 @@ def ask_ai(
|
|
| 174 |
for index, node in enumerate(response.source_nodes):
|
| 175 |
brief = node.source_text[:25].replace("\n", "")
|
| 176 |
nodes.append(
|
| 177 |
-
f"<details><summary>[{index+1}]\t{brief}...</summary><p>{node.source_text}</p></details>"
|
| 178 |
)
|
| 179 |
new_response = ret_text + "\n----------\n" + "\n\n".join(nodes)
|
| 180 |
logging.info(
|
|
|
|
| 13 |
from langchain.llms import OpenAI
|
| 14 |
import colorama
|
| 15 |
|
|
|
|
| 16 |
from modules.presets import *
|
| 17 |
from modules.utils import *
|
| 18 |
|
| 19 |
+
def get_index_name(file_src):
|
| 20 |
+
index_name = ""
|
| 21 |
+
for file in file_src:
|
| 22 |
+
index_name += os.path.basename(file.name)
|
| 23 |
+
index_name = sha1sum(index_name)
|
| 24 |
+
return index_name
|
| 25 |
|
| 26 |
def get_documents(file_src):
|
| 27 |
documents = []
|
|
|
|
| 28 |
logging.debug("Loading documents...")
|
| 29 |
logging.debug(f"file_src: {file_src}")
|
| 30 |
for file in file_src:
|
| 31 |
+
logging.info(f"loading file: {file.name}")
|
|
|
|
| 32 |
if os.path.splitext(file.name)[1] == ".pdf":
|
| 33 |
logging.debug("Loading PDF...")
|
| 34 |
CJKPDFReader = download_loader("CJKPDFReader")
|
| 35 |
loader = CJKPDFReader()
|
| 36 |
+
text_raw = loader.load_data(file=file.name)[0].text
|
| 37 |
elif os.path.splitext(file.name)[1] == ".docx":
|
| 38 |
logging.debug("Loading DOCX...")
|
| 39 |
DocxReader = download_loader("DocxReader")
|
| 40 |
loader = DocxReader()
|
| 41 |
+
text_raw = loader.load_data(file=file.name)[0].text
|
| 42 |
elif os.path.splitext(file.name)[1] == ".epub":
|
| 43 |
logging.debug("Loading EPUB...")
|
| 44 |
EpubReader = download_loader("EpubReader")
|
| 45 |
loader = EpubReader()
|
| 46 |
+
text_raw = loader.load_data(file=file.name)[0].text
|
| 47 |
else:
|
| 48 |
logging.debug("Loading text file...")
|
| 49 |
with open(file.name, "r", encoding="utf-8") as f:
|
| 50 |
+
text_raw = f.read()
|
| 51 |
+
text = add_space(text_raw)
|
| 52 |
+
documents += [Document(text)]
|
| 53 |
+
return documents
|
| 54 |
|
| 55 |
|
| 56 |
def construct_index(
|
| 57 |
+
api_key,
|
| 58 |
+
file_src,
|
| 59 |
+
max_input_size=4096,
|
| 60 |
+
num_outputs=1,
|
| 61 |
+
max_chunk_overlap=20,
|
| 62 |
+
chunk_size_limit=600,
|
| 63 |
+
embedding_limit=None,
|
| 64 |
+
separator=" ",
|
| 65 |
+
num_children=10,
|
| 66 |
+
max_keywords_per_chunk=10,
|
| 67 |
):
|
| 68 |
os.environ["OPENAI_API_KEY"] = api_key
|
| 69 |
chunk_size_limit = None if chunk_size_limit == 0 else chunk_size_limit
|
|
|
|
| 81 |
chunk_size_limit,
|
| 82 |
separator=separator,
|
| 83 |
)
|
| 84 |
+
index_name = get_index_name(file_src)
|
| 85 |
if os.path.exists(f"./index/{index_name}.json"):
|
| 86 |
logging.info("找到了缓存的索引文件,加载中……")
|
| 87 |
return GPTSimpleVectorIndex.load_from_disk(f"./index/{index_name}.json")
|
| 88 |
else:
|
| 89 |
try:
|
| 90 |
+
documents = get_documents(file_src)
|
| 91 |
logging.debug("构建索引中……")
|
| 92 |
index = GPTSimpleVectorIndex(
|
| 93 |
documents, llm_predictor=llm_predictor, prompt_helper=prompt_helper
|
|
|
|
| 101 |
|
| 102 |
|
| 103 |
def chat_ai(
|
| 104 |
+
api_key,
|
| 105 |
+
index,
|
| 106 |
+
question,
|
| 107 |
+
context,
|
| 108 |
+
chatbot,
|
| 109 |
+
reply_language,
|
| 110 |
):
|
| 111 |
os.environ["OPENAI_API_KEY"] = api_key
|
| 112 |
|
|
|
|
| 137 |
|
| 138 |
|
| 139 |
def ask_ai(
|
| 140 |
+
api_key,
|
| 141 |
+
index,
|
| 142 |
+
question,
|
| 143 |
+
prompt_tmpl,
|
| 144 |
+
refine_tmpl,
|
| 145 |
+
sim_k=1,
|
| 146 |
+
temprature=0,
|
| 147 |
+
prefix_messages=[],
|
| 148 |
+
reply_language="中文",
|
| 149 |
):
|
| 150 |
os.environ["OPENAI_API_KEY"] = api_key
|
| 151 |
|
|
|
|
| 178 |
for index, node in enumerate(response.source_nodes):
|
| 179 |
brief = node.source_text[:25].replace("\n", "")
|
| 180 |
nodes.append(
|
| 181 |
+
f"<details><summary>[{index + 1}]\t{brief}...</summary><p>{node.source_text}</p></details>"
|
| 182 |
)
|
| 183 |
new_response = ret_text + "\n----------\n" + "\n\n".join(nodes)
|
| 184 |
logging.info(
|