benticha committed on
Commit
a7d8a51
·
1 Parent(s): 6acad2f

qdrant feature added

Browse files
Files changed (3) hide show
  1. app.py +48 -4
  2. requirements.txt +2 -1
  3. utils.py +20 -2
app.py CHANGED
@@ -2,13 +2,17 @@ import streamlit as st
2
  from langchain import memory as lc_memory
3
  from langsmith import Client
4
  from streamlit_feedback import streamlit_feedback
5
- from utils import get_expression_chain, retriever
6
  from langchain_core.tracers.context import collect_runs
 
7
  from dotenv import load_dotenv
 
8
 
9
  load_dotenv()
10
  client = Client()
11
-
 
 
12
  st.set_page_config(page_title = "SUP'ASSISTANT")
13
  st.subheader("Hey there! How can I help you today!")
14
 
@@ -47,13 +51,33 @@ if prompt := st.chat_input(placeholder="What do you need to know about SUP'COM ?
47
  input_dict = {"input": prompt}
48
 
49
  with collect_runs() as cb:
50
- for chunk in chain.stream(input_dict, config={"tags": ["Streamlit Chat"]}):
51
  full_response += chunk.content
52
  message_placeholder.markdown(full_response + "▌")
53
  memory.save_context(input_dict, {"output": full_response})
54
  st.session_state.run_id = cb.traced_runs[0].id
55
  message_placeholder.markdown(full_response)
56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  if st.session_state.get("run_id"):
58
  run_id = st.session_state.run_id
59
  feedback = streamlit_feedback(
@@ -93,4 +117,24 @@ if st.session_state.get("run_id"):
93
  "score": score,
94
  }
95
  else:
96
- st.warning("Invalid feedback score.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  from langchain import memory as lc_memory
3
  from langsmith import Client
4
  from streamlit_feedback import streamlit_feedback
5
+ from utils import get_expression_chain, retriever, get_embeddings, create_qdrant_collection
6
  from langchain_core.tracers.context import collect_runs
7
+ from qdrant_client import QdrantClient
8
  from dotenv import load_dotenv
9
+ import os
10
 
11
  load_dotenv()
12
  client = Client()
13
+ qdrant_api=os.getenv("QDRANT_API_KEY")
14
+ qdrant_url=os.getenv("QDRANT_URL")
15
+ qdrant_client = QdrantClient(qdrant_url ,api_key=qdrant_api)
16
  st.set_page_config(page_title = "SUP'ASSISTANT")
17
  st.subheader("Hey there! How can I help you today!")
18
 
 
51
  input_dict = {"input": prompt}
52
 
53
  with collect_runs() as cb:
54
+ for chunk in chain.stream(input_dict, config={"tags": ["SUP'ASSISTANT"]}):
55
  full_response += chunk.content
56
  message_placeholder.markdown(full_response + "▌")
57
  memory.save_context(input_dict, {"output": full_response})
58
  st.session_state.run_id = cb.traced_runs[0].id
59
  message_placeholder.markdown(full_response)
60
 
61
+ run_id = st.session_state.run_id
62
+ question_embedding = get_embeddings(prompt)
63
+ answer_embedding = get_embeddings(full_response)
64
+ # Add question and answer to Qdrant
65
+ qdrant_client.upload_collection(
66
+ collection_name="chat-history",
67
+ payload=[
68
+ {"text": prompt, "type": "question", "question_ID": run_id},
69
+ {"text": full_response, "type": "answer", "question_ID": run_id}
70
+ ],
71
+ vectors=[
72
+ question_embedding,
73
+ answer_embedding,
74
+ ],
75
+ parallel=4,
76
+ max_retries=3,
77
+ )
78
+
79
+
80
+
81
  if st.session_state.get("run_id"):
82
  run_id = st.session_state.run_id
83
  feedback = streamlit_feedback(
 
117
  "score": score,
118
  }
119
  else:
120
+ st.warning("Invalid feedback score.")
121
+ if feedback.get("text"):
122
+ comment = feedback.get("text")
123
+ feedback_embedding = get_embeddings(comment)
124
+ else:
125
+ comment = "no comment"
126
+ feedback_embedding = get_embeddings(comment)
127
+
128
+
129
+ qdrant_client.upload_collection(
130
+ collection_name="chat-history",
131
+ payload=[
132
+ {"text": comment,"Score:":score, "type": "feedback", "question_ID": run_id}
133
+ ],
134
+ vectors=[
135
+ feedback_embedding
136
+ ],
137
+ parallel=4,
138
+ max_retries=3,
139
+ )
140
+
requirements.txt CHANGED
@@ -10,4 +10,5 @@ langchain-community
10
  rank_bm25
11
  cohere
12
  nomic[local]
13
- streamlit-feedback
 
 
10
  rank_bm25
11
  cohere
12
  nomic[local]
13
+ streamlit-feedback
14
+ qdrant-client
utils.py CHANGED
@@ -9,8 +9,10 @@ from dotenv import load_dotenv
9
  from langchain_core.prompts import ChatPromptTemplate
10
  from langchain_core.runnables import Runnable, RunnableMap
11
  from langchain.schema import BaseRetriever
12
- load_dotenv()
13
 
 
 
14
  def retriever(n_docs=5):
15
  vector_database_path = "sup-knowledge-eng-nomic"
16
 
@@ -44,6 +46,7 @@ def retriever(n_docs=5):
44
 
45
  return retriever
46
 
 
47
  rag_prompt = """You are an assistant for question-answering tasks.
48
  The questions that you will be asked will mainly be about SUP'COM (also known as Higher School Of Communication Of Tunis).
49
 
@@ -65,6 +68,7 @@ Answer:"""
65
  def format_docs(docs):
66
  return "\n\n".join(doc.page_content for doc in docs)
67
 
 
68
  def get_expression_chain(retriever: BaseRetriever, model_name="llama-3.1-70b-versatile", temp=0
69
  ) -> Runnable:
70
  """Return a chain defined primarily in LangChain Expression Language"""
@@ -90,4 +94,18 @@ def get_expression_chain(retriever: BaseRetriever, model_name="llama-3.1-70b-ver
90
  llm = ChatGroq(model=model_name, temperature=temp)
91
 
92
  chain = ingress | prompt | llm
93
- return chain
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  from langchain_core.prompts import ChatPromptTemplate
10
  from langchain_core.runnables import Runnable, RunnableMap
11
  from langchain.schema import BaseRetriever
12
+ from qdrant_client import models
13
 
14
+ load_dotenv()
15
+ #Retriever
16
  def retriever(n_docs=5):
17
  vector_database_path = "sup-knowledge-eng-nomic"
18
 
 
46
 
47
  return retriever
48
 
49
+ #Retriever prompt
50
  rag_prompt = """You are an assistant for question-answering tasks.
51
  The questions that you will be asked will mainly be about SUP'COM (also known as Higher School Of Communication Of Tunis).
52
 
 
68
  def format_docs(docs):
69
  return "\n\n".join(doc.page_content for doc in docs)
70
 
71
+ #RAG chain
72
  def get_expression_chain(retriever: BaseRetriever, model_name="llama-3.1-70b-versatile", temp=0
73
  ) -> Runnable:
74
  """Return a chain defined primarily in LangChain Expression Language"""
 
94
  llm = ChatGroq(model=model_name, temperature=temp)
95
 
96
  chain = ingress | prompt | llm
97
+ return chain
98
+
99
# Shared local embedding model, instantiated once at import time so every
# call to get_embeddings() reuses the same loaded weights.
embedding_model = NomicEmbeddings(model="nomic-embed-text-v1.5", inference_mode="local")


def get_embeddings(text):
    """Return the embedding vector for a single piece of *text*.

    Wraps the text in a one-element batch, embeds it with the shared Nomic
    model, and returns the single resulting vector.
    """
    vectors = embedding_model.embed([text], task_type='search_document')
    return vectors[0]
103
+
104
+
105
# Create or connect to a Qdrant collection
def create_qdrant_collection(client, collection_name):
    """Ensure *collection_name* exists on *client*, creating it if absent.

    The collection is configured for 768-dimensional vectors compared with
    cosine distance (presumably matching the size of the vectors produced
    by get_embeddings — confirm against the embedding model).

    Parameters:
        client: a connected QdrantClient instance.
        collection_name: name of the collection to create if missing.
    """
    # get_collections() returns CollectionDescription objects, not strings,
    # so membership must be tested against their .name attributes. The
    # previous check (`collection_name not in ...collections`) compared a
    # string to those objects, was therefore always True, and attempted to
    # re-create the collection on every call.
    existing_names = {c.name for c in client.get_collections().collections}
    if collection_name not in existing_names:
        client.create_collection(
            collection_name=collection_name,
            vectors_config=models.VectorParams(size=768, distance=models.Distance.COSINE),
        )