Phoenix21 committed on
Commit 0dad17d · verified · 1 Parent(s): 4f0c5f2

Update my_memory_logic.py

Files changed (1)
  1. my_memory_logic.py +60 -43
my_memory_logic.py CHANGED
@@ -1,47 +1,64 @@
  # my_memory_logic.py
  import os
- from langchain.memory import ConversationBufferMemory
- from langchain.chains import LLMChain
- from langchain.prompts.chat import (
-     ChatPromptTemplate,
-     SystemMessagePromptTemplate,
-     MessagesPlaceholder,
-     HumanMessagePromptTemplate,
- )
- # Import ChatGroq from the langchain_groq package
- from langchain_groq import ChatGroq
-
- # 1) Memory object for storing conversation messages
- memory = ConversationBufferMemory(return_messages=True)
-
- # 2) Restatement system prompt for question rewriting
- restatement_system_prompt = """
- You have a conversation history plus a new user question.
- Your ONLY job is to rewrite the user's latest question so it
- makes sense on its own.
- - Do NOT repeat or quote large sections of the conversation.
- - Do NOT provide any answer or summary.
- - Just produce a short, standalone question as your final output.
- """
-
-
- # 3) Build the ChatPromptTemplate
- restatement_prompt = ChatPromptTemplate.from_messages([
-     SystemMessagePromptTemplate.from_template(restatement_system_prompt),
-     MessagesPlaceholder(variable_name="chat_history"),
-     HumanMessagePromptTemplate.from_template("{input}")
- ])
-
- # 4) Initialize the ChatGroq LLM
- #    Ensure you have your GROQ_API_KEY set in the environment
- restatement_llm = ChatGroq(
-     model="llama3-70b-8192",
-     # model="mixtral-8x7b-32768"  # or whichever model
-     groq_api_key=os.environ["GROQ_API_KEY"]
- )

- # 5) Create the LLMChain for restatement
- restatement_chain = LLMChain(
-     llm=restatement_llm,
-     prompt=restatement_prompt
  )
  # my_memory_logic.py
  import os

+ # We'll import the session-based classes from langchain_core if you have them installed:
+ # If not, you'll need to install the correct package versions or adapt to your environment.
+ from langchain_core.chat_history import BaseChatMessageHistory
+ from langchain_community.chat_message_histories import ChatMessageHistory
+ from langchain_core.runnables.history import RunnableWithMessageHistory
+
+ # We'll assume you have a `rag_chain` from your pipeline code or can import it.
+ # For example:
+ # from pipeline import rag_chain
+
+ # For demonstration, let's just define a dummy "rag_chain" that returns "answer".
+ # In your real code, import your actual chain.
+ class DummyRagChain:
+     def invoke(self, inputs):
+         # returns a dictionary with "answer"
+         return {"answer": f"Dummy answer to '{inputs['input']}'."}
+
+ rag_chain = DummyRagChain()
+
+ ###############################################################################
+ # 1) We'll keep an in-memory store of session_id -> ChatMessageHistory
+ ###############################################################################
+ store = {}  # { "abc123": ChatMessageHistory(...) }
+
+ def get_session_history(session_id: str) -> BaseChatMessageHistory:
+     """
+     Retrieve or create a ChatMessageHistory object for the given session_id.
+     """
+     if session_id not in store:
+         store[session_id] = ChatMessageHistory()
+     return store[session_id]
+
+ ###############################################################################
+ # 2) Create the RunnableWithMessageHistory (conversational chain)
+ ###############################################################################
+ # If your snippet references `rag_chain`, combine it with get_session_history.
+ conversational_rag_chain = RunnableWithMessageHistory(
+     rag_chain,            # your main chain (RAG or pipeline)
+     get_session_history,  # function to fetch chat history for a session
+     input_messages_key="input",
+     history_messages_key="chat_history",
+     output_messages_key="answer"
  )
+
+ ###############################################################################
+ # 3) A convenience function to run a query with session-based memory
+ ###############################################################################
+ def run_with_session_memory(user_query: str, session_id: str) -> str:
+     """
+     A convenience wrapper that calls our `conversational_rag_chain`
+     with a specific session_id. This returns the final 'answer'.
+     """
+     response = conversational_rag_chain.invoke(
+         {"input": user_query},
+         config={
+             "configurable": {
+                 "session_id": session_id
+             }
+         }
+     )
+     return response["answer"]
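
For reference, a minimal usage sketch of the new session-based helper (not part of this commit). It assumes my_memory_logic.py is importable from the working directory, that the DummyRagChain placeholder has been swapped for a real LangChain Runnable (a plain object with only an invoke method may not satisfy RunnableWithMessageHistory), and that the session ids and questions below are arbitrary examples.

# usage sketch (illustrative only; module path, session ids, and queries are assumptions)
from my_memory_logic import run_with_session_memory

SESSION_ID = "demo-session"  # hypothetical session identifier

# Two turns sharing one session_id: the second call sees the first
# exchange through the ChatMessageHistory stored for that session.
print(run_with_session_memory("What does this product do?", SESSION_ID))
print(run_with_session_memory("Can you say that more simply?", SESSION_ID))

# A different session_id starts from an empty history.
print(run_with_session_memory("Hello!", "another-session"))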