Gainward777 committed on
Commit d0241f1 · verified · 1 Parent(s): 1de93cb

Update llm/utils.py

Files changed (1):
  llm/utils.py +74 -74
llm/utils.py CHANGED
@@ -1,74 +1,74 @@
- from langchain_community.llms import HuggingFaceEndpoint
- from langchain.memory import ConversationBufferMemory
- from langchain.chains import ConversationalRetrievalChain
- import gradio as gr
- import os
- from CustomRetriever import CustomRetriever
-
-
- API_TOKEN=os.getenv("TOKEN")
-
- # Initialize langchain LLM chain
- def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vdb,
-                         thold=0.8, progress=gr.Progress()):
-     llm = HuggingFaceEndpoint(
-         huggingfacehub_api_token = API_TOKEN,
-         repo_id=llm_model,
-         temperature = temperature,
-         max_new_tokens = max_tokens,
-         top_k = top_k,
-     )
-
-     memory = ConversationBufferMemory(
-         memory_key="chat_history",
-         output_key='answer',
-         return_messages=True
-     )
-
-     retr=CustomRetriever(vdb, thold=thold)
-     retriever=retr.retriever #vector_db.as_retriever()
-     qa_chain = ConversationalRetrievalChain.from_llm(
-         llm,
-         retriever=retriever,
-         chain_type="stuff",
-         memory=memory,
-         return_source_documents=True,
-         verbose=False,
-     )
-     return qa_chain
-
-
-
- # Initialize LLM
- def initialize_LLM(llm_temperature, max_tokens, top_k, vector_db, thold, progress=gr.Progress()):
-     # print("llm_option",llm_option)
-     llm_name = "mistralai/Mistral-7B-Instruct-v0.2" #list_llm[llm_option]
-     #print("llm_name: ",llm_name)
-     qa_chain = initialize_llmchain(llm_name, llm_temperature, max_tokens, top_k, vector_db, thold)
-     return qa_chain #, "QA chain initialized. Chatbot is ready!"
-
-
-
- def format_chat_history(chat_history):#message, chat_history): #no need message
-     formatted_chat_history = []
-     for user_message, bot_message in chat_history:
-         formatted_chat_history.append(f"User: {user_message}")
-         formatted_chat_history.append(f"Assistant: {bot_message}")
-     return formatted_chat_history
-
-
-
- def postprocess(response):
-     try:
-         result=response["answer"]
-         for doc in response['source_documents']:
-             file_doc="\n\nFile: " + doc.metadata["source"]
-             page="\nPage: " + str(doc.metadata["page"])
-             content="\nFragment: " + doc.page_content.strip()
-             result+=file_doc+page+content
-         return result
-     except:
-         return response["answer"]
-
-
-
 
+ from langchain_community.llms import HuggingFaceEndpoint
+ from langchain.memory import ConversationBufferMemory
+ from langchain.chains import ConversationalRetrievalChain
+ import gradio as gr
+ import os
+ from CustomRetriever import RetrieverWithScores #CustomRetriever
+
+
+ API_TOKEN=os.getenv("TOKEN")
+
+ # Initialize langchain LLM chain
+ def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vdb,
+                         thold=0.8, progress=gr.Progress()):
+     llm = HuggingFaceEndpoint(
+         huggingfacehub_api_token = API_TOKEN,
+         repo_id=llm_model,
+         temperature = temperature,
+         max_new_tokens = max_tokens,
+         top_k = top_k,
+     )
+
+     memory = ConversationBufferMemory(
+         memory_key="chat_history",
+         output_key='answer',
+         return_messages=True
+     )
+
+     #retr=CustomRetriever(vdb, thold=thold)
+     retriever=RetrieverWithScores(vdb, thold=thold) #retr.retriever #vector_db.as_retriever()
+     qa_chain = ConversationalRetrievalChain.from_llm(
+         llm,
+         retriever=retriever,
+         chain_type="stuff",
+         memory=memory,
+         return_source_documents=True,
+         verbose=False,
+     )
+     return qa_chain
+
+
+
+ # Initialize LLM
+ def initialize_LLM(llm_temperature, max_tokens, top_k, vector_db, thold, progress=gr.Progress()):
+     # print("llm_option",llm_option)
+     llm_name = "mistralai/Mistral-7B-Instruct-v0.2" #list_llm[llm_option]
+     #print("llm_name: ",llm_name)
+     qa_chain = initialize_llmchain(llm_name, llm_temperature, max_tokens, top_k, vector_db, thold)
+     return qa_chain #, "QA chain initialized. Chatbot is ready!"
+
+
+
+ def format_chat_history(chat_history):#message, chat_history): #no need message
+     formatted_chat_history = []
+     for user_message, bot_message in chat_history:
+         formatted_chat_history.append(f"User: {user_message}")
+         formatted_chat_history.append(f"Assistant: {bot_message}")
+     return formatted_chat_history
+
+
+
+ def postprocess(response):
+     try:
+         result=response["answer"]
+         for doc in response['source_documents']:
+             file_doc="\n\nFile: " + doc.metadata["source"]
+             page="\nPage: " + str(doc.metadata["page"])
+             content="\nFragment: " + doc.page_content.strip()
+             result+=file_doc+page+content
+         return result
+     except:
+         return response["answer"]
+
+
+
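
The functional change in this commit is the retriever wiring: the two-step `CustomRetriever(...).retriever` construction is replaced by a `RetrieverWithScores` instance passed directly to `ConversationalRetrievalChain`. The class itself lives in `CustomRetriever.py`, which is not part of this diff; below is a minimal sketch of what a score-thresholding retriever with this call signature could look like, assuming `vdb` is a LangChain vector store exposing `similarity_search_with_relevance_scores`. Only the class name and the `(vdb, thold=...)` signature come from the diff; the body is an assumption.

```python
# Hypothetical sketch -- the real CustomRetriever.py is not shown in this commit.
from typing import Any, List

from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever


class RetrieverWithScores(BaseRetriever):
    """Return only documents whose relevance score reaches the threshold."""

    vdb: Any            # vector store handed in by initialize_llmchain
    thold: float = 0.8  # same default as the thold parameter in the diff

    def __init__(self, vdb, thold: float = 0.8):
        # Accept vdb positionally to match the call site in initialize_llmchain.
        super().__init__(vdb=vdb, thold=thold)

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        # Relevance scores are normalized to [0, 1], so thold=0.8 keeps
        # only close matches and drops weak candidates entirely.
        scored = self.vdb.similarity_search_with_relevance_scores(query)
        return [doc for doc, score in scored if score >= self.thold]
```

Because the result is a regular `BaseRetriever`, the chain consumes it unchanged; reverting to the commented-out `vector_db.as_retriever()` would restore retrieval without the score cut-off.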
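For completeness, a hedged sketch of how the surrounding app presumably drives these helpers (the function names come from this file; `vector_db`, the parameter values, and the question are placeholders):

```python
# Hypothetical usage -- the calling Gradio app is not part of this diff.
# vector_db is assumed to be built elsewhere (e.g., a Chroma/FAISS store over
# PDFs whose documents carry "source" and "page" metadata, as postprocess expects).
qa_chain = initialize_LLM(llm_temperature=0.7, max_tokens=1024, top_k=3,
                          vector_db=vector_db, thold=0.8)

# The chain owns a ConversationBufferMemory, so only the new question is
# passed in; chat_history is filled from memory on each turn.
response = qa_chain.invoke({"question": "What does the document say about X?"})

# postprocess appends a File/Page/Fragment block for every source document,
# falling back to the bare answer if any metadata key is missing.
print(postprocess(response))
```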