Kathirsci commited on
Commit
ae669c2
·
verified ·
1 Parent(s): 9b35aa3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -73
app.py CHANGED
@@ -1,36 +1,11 @@
1
 
2
- import os
3
- from langchain_community.document_loaders import TextLoader
4
- from langchain.vectorstores import Chroma
5
- from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings
6
  from langchain_community.llms import HuggingFaceHub
7
  from langchain.prompts import PromptTemplate
8
- from langchain.memory import ConversationBufferMemory
9
- from langchain.chains import ConversationalRetrievalChain
10
- from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
11
- from langchain_core.output_parsers import StrOutputParser
12
- from langchain_core.runnables import RunnablePassthrough
13
- import gradio as gr
14
- import wandb
15
 
16
  # Initialize the chatbot
17
- loaders = []
18
- folder_path = "Data"
19
- for i in range(12):
20
- file_path = os.path.join(folder_path,"{}.txt".format(i))
21
- loaders.append(TextLoader(file_path))
22
- docs = []
23
- for loader in loaders:
24
- docs.extend(loader.load())
25
  HF_TOKEN = os.getenv("HF_TOKEN")
26
- embeddings = HuggingFaceInferenceAPIEmbeddings(
27
- api_key=HF_TOKEN,
28
- model_name="sentence-transformers/all-mpnet-base-v2"
29
- )
30
- vectordb = Chroma.from_documents(
31
- documents=docs,
32
- embedding=embeddings
33
- )
34
  llm = HuggingFaceHub(
35
  repo_id="google/gemma-1.1-7b-it",
36
  task="text-generation",
@@ -47,52 +22,12 @@ You are a Mental Health Chatbot. Help the user with their mental health concerns
47
  Use the context below to answer the questions {context}
48
  Question: {question}
49
  Helpful Answer:"""
50
-
51
  QA_CHAIN_PROMPT = PromptTemplate(input_variables=["context", "question"],template=template)
52
- memory = ConversationBufferMemory(
53
- memory_key="chat_history",
54
- return_messages=True
55
- )
56
- retriever = vectordb.as_retriever()
57
- qa = ConversationalRetrievalChain.from_llm(
58
- llm,
59
- retriever=retriever,
60
- memory=memory,
61
- )
62
- contextualize_q_system_prompt = """
63
- Given a chat history and the latest user question
64
- which might reference context in the chat history,
65
- formulate a standalone question
66
- which can be understood without the chat history.
67
- Do NOT answer the question, just reformulate it if needed and otherwise return it as is."""
68
- contextualize_q_prompt = ChatPromptTemplate.from_messages(
69
- [
70
- ("system", contextualize_q_system_prompt),
71
- MessagesPlaceholder(variable_name="chat_history"),
72
- ("human", "{question}"),
73
- ]
74
- )
75
- contextualize_q_chain = contextualize_q_prompt | llm | StrOutputParser()
76
- def contextualized_question(input: dict):
77
- if input.get("chat_history"):
78
- return contextualize_q_chain
79
- else:
80
- return input["question"]
81
- rag_chain = (
82
- RunnablePassthrough.assign(
83
- context=contextualized_question | retriever
84
- )
85
- | QA_CHAIN_PROMPT
86
- | llm
87
- )
88
- wandb.login(key=os.getenv("key"))
89
- os.environ["LANGCHAIN_WANDB_TRACING"] = "true"
90
- os.environ["WANDB_PROJECT"] = "Mental_Health_ChatBot"
91
- print("Welcome to the Mental Health Chatbot. How can I help you today?")
92
- chat_history = []
93
  def predict(message, history):
94
- ai_msg = rag_chain.invoke({"question": message, "chat_history": chat_history})
95
- idx = ai_msg.find("Answer")
96
- chat_history.extend([HumanMessage(content=message), ai_msg])
97
- return ai_msg[idx:]
98
  gr.ChatInterface(predict).launch()
 
 
 
1
 
2
+
3
+ import gradio as gr
 
 
4
  from langchain_community.llms import HuggingFaceHub
5
  from langchain.prompts import PromptTemplate
 
 
 
 
 
 
 
6
 
7
  # Initialize the chatbot
 
 
 
 
 
 
 
 
8
  HF_TOKEN = os.getenv("HF_TOKEN")
 
 
 
 
 
 
 
 
9
  llm = HuggingFaceHub(
10
  repo_id="google/gemma-1.1-7b-it",
11
  task="text-generation",
 
22
  Use the context below to answer the questions {context}
23
  Question: {question}
24
  Helpful Answer:"""
 
25
  QA_CHAIN_PROMPT = PromptTemplate(input_variables=["context", "question"],template=template)
26
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  def predict(message, history):
28
+ ai_msg = QA_CHAIN_PROMPT.apply({"question": message, "context": history})
29
+ return ai_msg
30
+
 
31
  gr.ChatInterface(predict).launch()
32
+
33
+