Gainward777 commited on
Commit
1364090
·
verified ·
1 Parent(s): b7cc73a

Update llm/utils.py

Browse files
Files changed (1) hide show
  1. llm/utils.py +6 -35
llm/utils.py CHANGED
@@ -3,7 +3,7 @@ from langchain.memory import ConversationBufferMemory
3
  from langchain.chains import ConversationalRetrievalChain
4
  import gradio as gr
5
  import os
6
- #from CustomRetriever import CustomRetriever
7
 
8
  from langchain.schema.retriever import BaseRetriever
9
  from langchain_core.documents import Document
@@ -14,41 +14,14 @@ from langchain.callbacks.manager import CallbackManagerForRetrieverRun
14
  from langchain_core.documents import Document
15
  from langchain_core.runnables import chain
16
 
17
-
18
  API_TOKEN=os.getenv("TOKEN")
19
 
20
 
21
- #Because of bugs in pydantic it is not possible to take it out retr_func and RetrieverWithScores into a separate neat class.
22
- #It is necessary to use dirty implementation through global variables.
23
- VDB=None
24
- THOLD=0.7
25
-
26
- @chain
27
- def retr_func(query: str)-> List[Document]:
28
-
29
- docs, scores = zip(*VDB.similarity_search_with_relevance_scores(query))#similarity_search_with_score(query))
30
- result=[]
31
- for doc, score in zip(docs, scores):
32
- if score>THOLD:
33
- doc.metadata["score"] = score
34
- result.append(doc)
35
- if len(result)==0:
36
- result.append(Document(metadata={}, page_content='No data'))
37
-
38
- return result
39
-
40
-
41
- class RetrieverWithScores(BaseRetriever):
42
- def _get_relevant_documents(self, query: str, *, run_manager: CallbackManagerForRetrieverRun)-> List[Document]:
43
- return retr_func.invoke(query)
44
-
45
-
46
-
47
  # Initialize langchain LLM chain
48
  def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vdb,
49
  thold=0.8, progress=gr.Progress()):
50
- global VDB
51
- global THOLD
52
 
53
  llm = HuggingFaceEndpoint(
54
  huggingfacehub_api_token = API_TOKEN,
@@ -64,13 +37,11 @@ def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vdb,
64
  return_messages=True
65
  )
66
 
67
- VDB=vdb
68
- THOLD=thold
69
- #retr=CustomRetriever(vdb, thold=thold)
70
- #retriever=retr.retriever
71
  qa_chain = ConversationalRetrievalChain.from_llm(
72
  llm,
73
- retriever=RetrieverWithScores(),#retriever,
74
  chain_type="stuff",
75
  memory=memory,
76
  return_source_documents=True,
 
3
  from langchain.chains import ConversationalRetrievalChain
4
  import gradio as gr
5
  import os
6
+ from llm.CustomRetriever import CustomRetriever
7
 
8
  from langchain.schema.retriever import BaseRetriever
9
  from langchain_core.documents import Document
 
14
  from langchain_core.documents import Document
15
  from langchain_core.runnables import chain
16
 
 
17
  API_TOKEN=os.getenv("TOKEN")
18
 
19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  # Initialize langchain LLM chain
21
  def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vdb,
22
  thold=0.8, progress=gr.Progress()):
23
+ #global VDB
24
+ #global THOLD
25
 
26
  llm = HuggingFaceEndpoint(
27
  huggingfacehub_api_token = API_TOKEN,
 
37
  return_messages=True
38
  )
39
 
40
+ #VDB=vdb
41
+ #THOLD=thold
 
 
42
  qa_chain = ConversationalRetrievalChain.from_llm(
43
  llm,
44
+ retriever=CustomRetriever(vectorstore=vdb, thold=thold),#RetrieverWithScores(),
45
  chain_type="stuff",
46
  memory=memory,
47
  return_source_documents=True,