JaganathC committed
Commit 7c21ccc · verified · 1 Parent(s): 3ddcd56

Update retrieval.py

Files changed (1):
  1. retrieval.py +18 -46
retrieval.py CHANGED
@@ -2,9 +2,6 @@
 LLM chain retrieval
 """
 
-
-
-
 import json
 import gradio as gr
 
@@ -14,16 +11,6 @@ from langchain_huggingface import HuggingFaceEndpoint
 from langchain_core.prompts import PromptTemplate
 
 
-# Add system template for RAG application
-PROMPT_TEMPLATE = """
-You are an assistant for question-answering tasks. Use the following pieces of context to answer the question at the end.
-If you don't know the answer, just say that you don't know, don't try to make up an answer. Keep the answer concise.
-Question: {question}
-Context: {context}
-Helpful Answer:
-"""
-
-
 # Initialize langchain LLM chain
 def initialize_llmchain(
     llm_model,
@@ -37,22 +24,11 @@ def initialize_llmchain(
     """Initialize Langchain LLM chain"""
 
     progress(0.1, desc="Initializing HF tokenizer...")
-    # HuggingFaceHub uses HF inference endpoints
     progress(0.5, desc="Initializing HF Hub...")
-    # Use of trust_remote_code as model_kwargs
-    # Warning: langchain issue
-    # URL: https://github.com/langchain-ai/langchain/issues/6080
-
-    # if 'Llama' in llm_model:
-    # task = "conversational"
-    # else:
-    # task = "text-generation"
-    # print(f"Task: {task}")
 
     llm = HuggingFaceEndpoint(
         repo_id=llm_model,
         task="text-generation",
-        #task="conversational",
         provider="hf-inference",
         temperature=temperature,
         max_new_tokens=max_tokens,
@@ -62,18 +38,20 @@
 
     progress(0.75, desc="Defining buffer memory...")
     memory = ConversationBufferMemory(
-        memory_key="chat_history", output_key="answer", return_messages=True
+        memory_key="chat_history",
+        output_key="answer",
+        return_messages=True,
     )
-    # retriever=vector_db.as_retriever(search_type="similarity", search_kwargs={'k': 3})
-    retriever = vector_db.as_retriever()
+    retriever = vector_db.as_retriever(search_type="similarity", search_kwargs={'k': top_k})
 
     progress(0.8, desc="Defining retrieval chain...")
     with open('prompt_template.json', 'r') as file:
-        system_prompt = json.load(file)
+        system_prompt = json.load(file)
     prompt_template = system_prompt["prompt"]
     rag_prompt = PromptTemplate(
         template=prompt_template, input_variables=["context", "question"]
     )
+
     qa_chain = ConversationalRetrievalChain.from_llm(
         llm,
         retriever=retriever,
@@ -81,17 +59,16 @@ def initialize_llmchain(
         memory=memory,
         combine_docs_chain_kwargs={"prompt": rag_prompt},
         return_source_documents=True,
-        # return_generated_question=False,
         verbose=False,
     )
-    progress(0.9, desc="Done!")
 
+    progress(0.9, desc="Done!")
     return qa_chain
 
 
+# Format chat history
 def format_chat_history(message, chat_history):
-    """Format chat history for llm chain"""
-
+    """Format chat history for LLM"""
     formatted_chat_history = []
     for user_message, bot_message in chat_history:
         formatted_chat_history.append(f"User: {user_message}")
@@ -99,27 +76,22 @@ def format_chat_history(message, chat_history):
     return formatted_chat_history
 
 
+# Invoke QA chain with history
 def invoke_qa_chain(qa_chain, message, history):
     """Invoke question-answering chain"""
-
     formatted_chat_history = format_chat_history(message, history)
-    # print("formatted_chat_history",formatted_chat_history)
 
-    # Generate response using QA chain
-    response = qa_chain.invoke(
-        {"question": message, "chat_history": formatted_chat_history}
-    )
+    response = qa_chain.invoke({
+        "question": message,
+        "chat_history": formatted_chat_history,
+    })
 
     response_sources = response["source_documents"]
-
     response_answer = response["answer"]
-    if response_answer.find("Helpful Answer:") != -1:
-        response_answer = response_answer.split("Helpful Answer:")[-1]
-
-    # Append user message and response to chat history
-    new_history = history + [(message, response_answer)]
 
-    # print ('chat response: ', response_answer)
-    # print('DB source', response_sources)
+    # Clean up if "Helpful Answer:" is included
+    if "Helpful Answer:" in response_answer:
+        response_answer = response_answer.split("Helpful Answer:")[-1].strip()
 
+    new_history = history + [(message, response_answer)]
     return qa_chain, new_history, response_sources
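Note: with the hard-coded PROMPT_TEMPLATE gone, initialize_llmchain now expects a prompt_template.json file next to retrieval.py and only requires a top-level "prompt" key (that is all system_prompt["prompt"] reads). The file itself is not part of this commit, so the snippet below is just a sketch of one plausible shape for it, reusing the prompt text deleted above; whatever the real wording is, it must keep the {context} and {question} placeholders, since PromptTemplate is built with those input variables.

# Sketch: write a prompt_template.json in the shape the updated code reads.
# The prompt text is copied from the PROMPT_TEMPLATE string removed in this
# commit; the real file may word it differently.
import json

prompt = """\
You are an assistant for question-answering tasks. Use the following pieces of context to answer the question at the end.
If you don't know the answer, just say that you don't know, don't try to make up an answer. Keep the answer concise.
Question: {question}
Context: {context}
Helpful Answer:
"""

with open("prompt_template.json", "w") as file:
    # json.dump escapes the newlines, so json.load in retrieval.py gets the
    # original multi-line string back under the "prompt" key.
    json.dump({"prompt": prompt}, file, indent=2)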
 
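For completeness, a minimal, hypothetical driver for the refactored invoke_qa_chain. The real qa_chain would come from initialize_llmchain, whose full parameter list is not visible in this diff, so a stub object stands in for it and only mimics the two response keys invoke_qa_chain reads. The history format, a list of (user_message, bot_message) tuples, follows what format_chat_history iterates over; running the sketch assumes retrieval.py and its imports (gradio, langchain) are available.

# Hypothetical smoke test for invoke_qa_chain. StubChain is NOT the real chain;
# it only returns the "answer" / "source_documents" keys that invoke_qa_chain
# reads from qa_chain.invoke(...).
from retrieval import invoke_qa_chain


class StubChain:
    def invoke(self, inputs):
        # A real ConversationalRetrievalChain receives {"question", "chat_history"}
        # and would return retrieved Documents under "source_documents".
        return {
            "answer": "Helpful Answer: Paris is the capital of France.",
            "source_documents": [],
        }


history = []  # list of (user_message, bot_message) tuples
qa_chain, history, sources = invoke_qa_chain(
    StubChain(), "What is the capital of France?", history
)
print(history[-1])  # ('What is the capital of France?', 'Paris is the capital of France.')
print(sources)      # []

The "Helpful Answer:" prefix in the stub's answer exercises the new cleanup branch, so the tuple stored in history contains only the stripped answer.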