benticha committed on
Commit 4ff2d98 · 1 Parent(s): 6c63a94

initial commit

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.sqlite3 filter=lfs diff=lfs merge=lfs -text
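A note on the new rule: a pattern like this is typically produced by running the command below, so the Chroma database added in this commit is stored as a Git LFS pointer rather than a full blob:

git lfs track "*.sqlite3"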
app.py ADDED
@@ -0,0 +1,96 @@
+import streamlit as st
+from langchain import memory as lc_memory
+from langsmith import Client
+from streamlit_feedback import streamlit_feedback
+from utils import get_expression_chain, retriever as build_retriever
+from langchain_core.tracers.context import collect_runs
+from dotenv import load_dotenv
+
+load_dotenv()
+client = Client()
+
+st.set_page_config(page_title="SUP'ASSISTANT")
+st.subheader("Hey there! How can I help you today?")
+
+memory = lc_memory.ConversationBufferMemory(
+    chat_memory=lc_memory.StreamlitChatMessageHistory(key="langchain_messages"),
+    return_messages=True,
+    memory_key="chat_history",
+)
+st.sidebar.markdown("## Feedback Scale")
+feedback_option = (
+    "thumbs" if st.sidebar.toggle(label="`Faces` ⇄ `Thumbs`", value=False) else "faces"
+)
+with st.sidebar:
+    model_name = st.selectbox("**Model**", options=["llama-3.1-70b-versatile", "gemma2-9b-it", "gemma-7b-it", "llama-3.2-3b-preview", "llama3-70b-8192", "mixtral-8x7b-32768"])
+    temp = st.slider("**Temperature**", min_value=0.0, max_value=1.0, step=0.001)
+    n_docs = st.number_input("**Number of retrieved documents**", min_value=0, max_value=10, value=5, step=1)
+if st.sidebar.button("Clear message history"):
+    print("Clearing message history")
+    memory.clear()
+
+retriever = build_retriever(n_docs=n_docs)  # aliased import avoids shadowing utils.retriever
+# Create the RAG chain
+chain = get_expression_chain(retriever, model_name, temp)
+
+for msg in st.session_state.langchain_messages:
+    avatar = "🦜" if msg.type == "ai" else None
+    with st.chat_message(msg.type, avatar=avatar):
+        st.markdown(msg.content)
+
+if prompt := st.chat_input(placeholder="Ask me a question!"):
+    st.chat_message("user").write(prompt)
+    with st.chat_message("assistant", avatar="🦜"):
+        message_placeholder = st.empty()
+        full_response = ""
+        # Define the basic input structure for the chain
+        input_dict = {"input": prompt}
+
+        with collect_runs() as cb:
+            for chunk in chain.stream(input_dict, config={"tags": ["Streamlit Chat"]}):
+                full_response += chunk.content
+                message_placeholder.markdown(full_response + "▌")
+            memory.save_context(input_dict, {"output": full_response})
+            st.session_state.run_id = cb.traced_runs[0].id
+        message_placeholder.markdown(full_response)
+
+if st.session_state.get("run_id"):
+    run_id = st.session_state.run_id
+    feedback = streamlit_feedback(
+        feedback_type=feedback_option,
+        optional_text_label="[Optional] Please provide an explanation",
+        key=f"feedback_{run_id}",
+    )
+
+    # Score mappings for both the "thumbs" and "faces" feedback widgets
+    score_mappings = {
+        "thumbs": {"👍": 1, "👎": 0},
+        "faces": {"😀": 1, "🙂": 0.75, "😐": 0.5, "🙁": 0.25, "😞": 0},
+    }
+
+    # Get the score mapping for the selected feedback option
+    scores = score_mappings[feedback_option]
+
+    if feedback:
+        # Map the emoji returned by the widget to a numeric score
+        score = scores.get(feedback["score"])
+
+        if score is not None:
+            # Formulate a feedback key combining the feedback option
+            # and the raw score value
+            feedback_type_str = f"{feedback_option} {feedback['score']}"
+
+            # Record the feedback in LangSmith with the formulated key
+            # and the optional free-text comment
+            feedback_record = client.create_feedback(
+                run_id,
+                feedback_type_str,
+                score=score,
+                comment=feedback.get("text"),
+            )
+            st.session_state.feedback = {
+                "feedback_id": str(feedback_record.id),
+                "score": score,
+            }
+        else:
+            st.warning("Invalid feedback score.")
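A note on configuration: load_dotenv() together with the LangSmith, Groq, Cohere, and Nomic clients used above suggests the app reads its API keys from a local .env file that is not part of this commit. A minimal sketch, with variable names taken from each library's documented defaults rather than from this commit:

# Hypothetical .env — contents assumed, not included in this commit
GROQ_API_KEY=...            # used by ChatGroq
LANGCHAIN_API_KEY=...       # used by langsmith.Client
LANGCHAIN_TRACING_V2=true   # typically required so feedback can attach to traced runs
COHERE_API_KEY=...          # used by CohereRerank
NOMIC_API_KEY=...           # used by NomicEmbeddings (may be optional with inference_mode="local")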
requirements.txt ADDED
@@ -0,0 +1,13 @@
+langchain-groq
+langchain-core
+streamlit
+langchain-chroma
+langchain-nomic
+langchain
+nomic
+python-dotenv
+langchain-community
+rank_bm25
+cohere
+nomic[local]
+streamlit-feedback
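A setup note: none of these entries pins a version, so `pip install -r requirements.txt` resolves to whatever is current and may drift from what this commit was tested against. Also, nomic[local] already includes nomic, so the bare nomic entry is redundant but harmless.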
sup-knowledge-eng-nomic/chroma.sqlite3 ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:112a7ee3f7fb675803ed49ffe7901311156373f8ba3142c3a3026b2f3936d633
+size 7704576
sup-knowledge-eng-nomic/ec6754ec-5fa6-4b04-bfb6-d2f052cd81fe/data_level0.bin ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a13e72541800c513c73dccea69f79e39cf4baef4fa23f7e117c0d6b0f5f99670
+size 3212000
sup-knowledge-eng-nomic/ec6754ec-5fa6-4b04-bfb6-d2f052cd81fe/header.bin ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0ec6df10978b056a10062ed99efeef2702fa4a1301fad702b53dd2517103c746
+size 100
sup-knowledge-eng-nomic/ec6754ec-5fa6-4b04-bfb6-d2f052cd81fe/length.bin ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a5fb64b021f47ff585087f63e019088911fa892704ffa3e9506f3a120d807cfa
+size 4000
sup-knowledge-eng-nomic/ec6754ec-5fa6-4b04-bfb6-d2f052cd81fe/link_lists.bin ADDED
File without changes
utils.py ADDED
@@ -0,0 +1,93 @@
+from langchain_chroma import Chroma
+from langchain_nomic.embeddings import NomicEmbeddings
+from langchain_core.documents import Document
+from langchain.retrievers.document_compressors import CohereRerank
+from langchain.retrievers import ContextualCompressionRetriever
+from langchain.retrievers import BM25Retriever, EnsembleRetriever
+from langchain_groq import ChatGroq
+from dotenv import load_dotenv
+from langchain_core.prompts import ChatPromptTemplate
+from langchain_core.runnables import Runnable, RunnableMap
+from langchain.schema import BaseRetriever
+load_dotenv()
+
+def retriever(n_docs=5):
+    vector_database_path = "sup-knowledge-eng-nomic"
+
+    embeddings_model = NomicEmbeddings(model="nomic-embed-text-v1.5", inference_mode="local")
+
+    # Load the persisted Chroma vector store
+    vectorstore = Chroma(collection_name="sup-store-eng-nomic",
+                         persist_directory=vector_database_path,
+                         embedding_function=embeddings_model)
+
+    vs_retriever = vectorstore.as_retriever(search_kwargs={"k": n_docs})  # as_retriever expects search_kwargs, not a bare k
+
+    data = vectorstore.get()  # fetch the stored texts and metadata once
+    texts, metadatas = data["documents"], data["metadatas"]
+
+    documents = []
+    for text, metadata in zip(texts, metadatas):
+        documents.append(Document(page_content=text, metadata=metadata))
+
+    keyword_retriever = BM25Retriever.from_documents(documents)
+    keyword_retriever.k = n_docs
+
+    # Blend semantic and keyword retrieval, then rerank with Cohere
+    ensemble_retriever = EnsembleRetriever(retrievers=[vs_retriever, keyword_retriever],
+                                           weights=[0.5, 0.5])
+
+    compressor = CohereRerank(model="rerank-english-v3.0")
+    retriever = ContextualCompressionRetriever(
+        base_compressor=compressor, base_retriever=ensemble_retriever
+    )
+
+    return retriever
+
+rag_prompt = """You are an assistant for question-answering tasks.
+The questions you will be asked will mainly be about SUP'COM (also known as the Higher School of Communication of Tunis).
+
+Here is the context to use to answer the question:
+
+{context}
+
+Think carefully about the above context.
+
+Now, review the user question:
+
+{input}
+
+Provide an answer to this question using only the above context.
+
+Answer:"""
+
+# Post-processing: join retrieved documents into a single context string
+def format_docs(docs):
+    return "\n\n".join(doc.page_content for doc in docs)
+
+def get_expression_chain(retriever: BaseRetriever, model_name="llama-3.1-70b-versatile",
+                         temp=0) -> Runnable:
+    """Return a chain defined primarily in LangChain Expression Language."""
+    def retrieve_context(input_text):
+        # Use the retriever to fetch and format relevant documents
+        docs = retriever.get_relevant_documents(input_text)
+        return format_docs(docs)
+
+    ingress = RunnableMap(
+        {
+            "input": lambda x: x["input"],
+            "context": lambda x: retrieve_context(x["input"]),
+        }
+    )
+    prompt = ChatPromptTemplate.from_messages(
+        [
+            (
+                "system",
+                rag_prompt,
+            )
+        ]
+    )
+    llm = ChatGroq(model=model_name, temperature=temp)
+
+    chain = ingress | prompt | llm
+    return chain
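For reference, a minimal sketch of exercising these helpers outside Streamlit; the question string is illustrative, and it assumes the Chroma store committed under sup-knowledge-eng-nomic is present and the relevant API keys are set:

from utils import get_expression_chain, retriever

# Build the hybrid retriever and the RAG chain with the app's defaults
rag_retriever = retriever(n_docs=5)
chain = get_expression_chain(rag_retriever, model_name="llama-3.1-70b-versatile", temp=0)

# Stream the answer token by token, as app.py does
for chunk in chain.stream({"input": "What is SUP'COM known for?"}):
    print(chunk.content, end="", flush=True)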