Update llm/utils.py
llm/utils.py  CHANGED  (+9 -16)
@@ -18,14 +18,13 @@ from langchain_core.runnables import chain
 API_TOKEN=os.getenv("TOKEN")
 
 
-
+#Because of bugs in pydantic it is not possible to move retr_func and RetrieverWithScores out into a separate neat class.
+#It is necessary to use a dirty implementation through global variables.
 VDB=None
-
 THOLD=0.7
 
 @chain
-def retr_func(query: str)-> List[Document]: #(vdb, query: str)-> List[Document]:
-    #global VDB
+def retr_func(query: str)-> List[Document]:
 
     docs, scores = zip(*VDB.similarity_search_with_relevance_scores(query))#similarity_search_with_score(query))
     result=[]
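The two new comments state the design constraint behind this hunk: BaseRetriever is a pydantic model, so stashing a vector store or threshold on an instance as ad-hoc attributes trips pydantic validation, and the commit instead keeps both in module-level globals that initialize_llmchain later overwrites. A minimal sketch of the resulting pattern follows; note the filtering lines between result=[] and the len(result)==0 check are elided from this diff, so the score >= THOLD test is an assumption about how THOLD is applied.

from typing import List
from langchain_core.documents import Document
from langchain_core.runnables import chain

VDB = None    # vector store handle, set by initialize_llmchain
THOLD = 0.7   # relevance-score cutoff, also overwritten at init time

@chain
def retr_func(query: str) -> List[Document]:
    # similarity_search_with_relevance_scores yields (Document, score)
    # pairs with scores normalized to [0, 1].
    pairs = VDB.similarity_search_with_relevance_scores(query)
    # Assumed thresholding; the actual filter lines are not in this diff.
    result = [doc for doc, score in pairs if score >= THOLD]
    if len(result) == 0:
        result.append(Document(metadata={}, page_content="No data"))
    return result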
@@ -36,13 +35,7 @@ def retr_func(query: str)-> List[Document]: #(vdb, query: str)-> List[Document]:
     if len(result)==0:
         result.append(Document(metadata={}, page_content='No data'))
 
-
-    print(THOLD)
-    print()
-    print(result)
-    print()
-
-    return result #docs
+    return result
 
 
 class RetrieverWithScores(BaseRetriever):
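This hunk drops the debug prints and returns the filtered documents directly. The body of RetrieverWithScores sits outside the hunk, but given the globals above it presumably just delegates to retr_func. A hedged guess at its shape, deliberately declaring no pydantic fields (which is exactly why VDB and THOLD cannot live on the class):

from typing import List
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever

class RetrieverWithScores(BaseRetriever):
    # Field-free on purpose: BaseRetriever is a pydantic model, and an
    # undeclared attribute such as a vector-store handle fails validation.
    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        # retr_func is a Runnable thanks to @chain, hence .invoke()
        return retr_func.invoke(query)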
@@ -74,7 +67,7 @@ def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vdb,
     VDB=vdb
     THOLD=thold
     #retr=CustomRetriever(vdb, thold=thold)
-    #retriever=retr.retriever
+    #retriever=retr.retriever
     qa_chain = ConversationalRetrievalChain.from_llm(
         llm,
         retriever=RetrieverWithScores(),#retriever,
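Only three arguments of the ConversationalRetrievalChain.from_llm call are visible here. For orientation, a hypothetical assembly in the style of the surrounding code; the HuggingFaceEndpoint parameters and the memory configuration are assumptions, not part of this commit:

from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from langchain_community.llms import HuggingFaceEndpoint

llm = HuggingFaceEndpoint(                    # assumed LLM wrapper
    repo_id="mistralai/Mistral-7B-Instruct-v0.2",
    temperature=0.7,
    max_new_tokens=1024,
    top_k=3,
    huggingfacehub_api_token=API_TOKEN,
)
memory = ConversationBufferMemory(            # assumed memory setup
    memory_key="chat_history",
    output_key="answer",
    return_messages=True,
)
qa_chain = ConversationalRetrievalChain.from_llm(
    llm,
    retriever=RetrieverWithScores(),          # the class from this commit
    memory=memory,
    return_source_documents=True,
)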
@@ -90,14 +83,14 @@ def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vdb,
 # Initialize LLM
 def initialize_LLM(llm_temperature, max_tokens, top_k, vector_db, thold, progress=gr.Progress()):
     # print("llm_option",llm_option)
-    llm_name = "mistralai/Mistral-7B-Instruct-v0.2"
+    llm_name = "mistralai/Mistral-7B-Instruct-v0.2"
     #print("llm_name: ",llm_name)
     qa_chain = initialize_llmchain(llm_name, llm_temperature, max_tokens, top_k, vector_db, thold)
-    return qa_chain
+    return qa_chain
 
 
 
-def format_chat_history(chat_history)
+def format_chat_history(chat_history):
     formatted_chat_history = []
     for user_message, bot_message in chat_history:
         formatted_chat_history.append(f"User: {user_message}")
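The substantive fix in this hunk is the previously missing colon on def format_chat_history(chat_history), which was a syntax error. The loop body continues past the hunk; a plausible completion, where the "Assistant:" line and the return are assumed:

def format_chat_history(chat_history):
    # Flattens Gradio-style (user_message, bot_message) tuples into
    # alternating prompt lines for the condense-question step.
    formatted_chat_history = []
    for user_message, bot_message in chat_history:
        formatted_chat_history.append(f"User: {user_message}")
        # Assumed continuation; only the "User:" line appears in the diff.
        formatted_chat_history.append(f"Assistant: {bot_message}")
    return formatted_chat_history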
@@ -118,7 +111,7 @@ def postprocess(response):
             result+=file_doc+page+content
         return result
     except:
-        return "I don't know."
+        return "I don't know."
 
 
 
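Only the tail of postprocess appears in the diff: it builds a string from the chain's source documents and falls back to "I don't know." when anything in the try block raises. A sketch of the surrounding shape; the metadata keys and string formatting below are assumptions:

def postprocess(response):
    # Hypothetical reconstruction around the visible tail of the function;
    # the metadata keys and formatting are assumptions.
    try:
        result = ""
        for doc in response["source_documents"]:
            file_doc = f"\nSource: {doc.metadata.get('source', 'unknown')}"
            page = f", page {doc.metadata.get('page', '?')}\n"
            content = doc.page_content
            result += file_doc + page + content
        return result
    except Exception:
        # The committed code uses a bare `except:`; catching Exception is
        # the safer idiomatic equivalent.
        return "I don't know."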