Update app.py
app.py
CHANGED
@@ -1,12 +1,63 @@
 import gradio as gr
 from transformers import pipeline
 from huggingface_hub import InferenceClient, login, snapshot_download
-from langchain_community.vectorstores import FAISS
+from langchain_community.vectorstores import FAISS, DistanceStrategy
 from langchain_huggingface import HuggingFaceEmbeddings
 import os
 import pandas as pd
 from datetime import datetime
 
+from smolagents import Tool, HfApiModel, ToolCallingAgent
+from langchain_core.vectorstores import VectorStore
+
+
+class RetrieverTool(Tool):
+    name = "retriever"
+    description = "Using semantic similarity in German, French, English and Italian, retrieves some documents from the knowledge base that have the closest embeddings to the input query."
+    inputs = {
+        "query": {
+            "type": "string",
+            "description": "The query to perform. This should be semantically close to your target documents. Use the affirmative form rather than a question.",
+        }
+    }
+    output_type = "string"
+
+    def __init__(self, vectordb: VectorStore, **kwargs):
+        super().__init__(**kwargs)
+        self.vectordb = vectordb
+
+    def forward(self, query: str) -> str:
+        assert isinstance(query, str), "Your search query must be a string"
+
+        docs = self.vectordb.similarity_search(
+            query,
+            k=7,
+        )
+
+        df = pd.read_csv("bger_cedh_db 1954-2024.csv")
+
+        spacer = " \n"
+        context = ""
+        nb_char = 100
+
+        for doc in docs:
+            case_text = df[df["case_url"] == doc.metadata["case_url"]].case_text.values[0]
+            index = case_text.find(doc.page_content)
+            start = max(0, index - nb_char)
+            end = min(len(case_text), index + len(doc.page_content) + nb_char)
+            case_text_summary = case_text[start:end]
+
+            context += "#######" + spacer
+            context += "# Case number: " + doc.metadata["case_ref"] + " " + doc.metadata["case_nb"] + spacer
+            context += "# Case source: " + ("Swiss Federal Court" if doc.metadata["case_ref"] == "ATF" else "European Court of Human Rights") + spacer
+            context += "# Case date: " + doc.metadata["case_date"] + spacer
+            context += "# Case url: " + doc.metadata["case_url"] + spacer
+            #context += "# Case text: " + doc.page_content + spacer
+            context += "# Case extract: " + case_text_summary + spacer
+
+        return "\nRetrieved documents:\n" + context
+
 
 """
 For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
@@ -14,9 +65,7 @@ For more information on `huggingface_hub` Inference API support, please check th
 HF_TOKEN=os.getenv('TOKEN')
 login(HF_TOKEN)
 
-
-#model = "google/mt5-small"
-model = "mistralai/Mistral-7B-Instruct-v0.3"
+model = "meta-llama/Meta-Llama-3-8B-Instruct"
 
 client = InferenceClient(model)
 
@@ -24,98 +73,36 @@ folder = snapshot_download(repo_id="umaiku/faiss_index", repo_type="dataset", lo
 
 embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2")
 
-vector_db = FAISS.load_local("faiss_index_mpnet_cos", embeddings, allow_dangerous_deserialization=True)
+vector_db = FAISS.load_local("faiss_index_mpnet_cos", embeddings, allow_dangerous_deserialization=True, distance_strategy=DistanceStrategy.COSINE)
 
 df = pd.read_csv("faiss_index/bger_cedh_db 1954-2024.csv")
 
+retriever_tool = RetrieverTool(vector_db)
+agent = ToolCallingAgent(tools=[retriever_tool], model=HfApiModel(model))
+
+def respond(message, history: list[tuple[str, str]], system_message, max_tokens, temperature, top_p, score,):
-    history: list[tuple[str, str]],
-    system_message,
-    max_tokens,
-    temperature,
-    top_p,
-    score,
-):
-    #messages = [{"role": "system", "content": system_message}]
-
-    print(system_message)
-
-    # retriever = vector_db.as_retriever(search_type="similarity_score_threshold", search_kwargs={"score_threshold": score, "k": 10})
-    # retriever = vector_db.as_retriever(search_type="similarity", search_kwargs={"k": 10})
-    # retriever = vector_db.as_retriever(search_type="mmr")
-    # documents = retriever.invoke(message)
-
-    documents_en = vector_db.similarity_search_with_score(prompt_en, k=4)
-    print(prompt_en)
-
-    documents_fr = vector_db.similarity_search_with_score(prompt_fr, k=4)
-    print(prompt_fr)
-
-    documents_it = vector_db.similarity_search_with_score(prompt_it, k=4)
-    print(prompt_it)
-
-    documents = documents_en + documents_de + documents_fr + documents_it
-
-    documents = sorted(documents, key=lambda x: x[1])[:4]
-
-    nb_char = 2000
-
-    #print(message)
-    print(f"* Documents found: {len(documents)}")
-
-    for doc in documents:
-        case_text = df[df["case_url"] == doc[0].metadata["case_url"]].case_text.values[0]
-        index = case_text.find(doc[0].page_content)
-        start = max(0, index - nb_char)
-        end = min(len(case_text), index + len(doc[0].page_content) + nb_char)
-        case_text_summary = case_text[start:end]
-
-        context += "#######" + spacer
-        context += "# Case number: " + doc[0].metadata["case_nb"] + spacer
-        context += "# Case source: " + ("Swiss Federal Court" if doc[0].metadata["case_ref"] == "ATF" else "European Court of Human Rights") + spacer
-        context += "# Case date: " + doc[0].metadata["case_date"] + spacer
-        context += "# Case url: " + doc[0].metadata["case_url"] + spacer
-        #context += "# Case text: " + doc[0].page_content + spacer
-        context += "Case extract: " + case_text_summary + spacer
-
-        #print("# Case number: " + doc.metadata["case_nb"] + spacer)
-        #print("# Case url: " + doc.metadata["case_url"] + spacer)
-
-    system_message += f"""A user is asking you the following question: {message}
-Please answer the user in the same language that he used in his question using ONLY the following given context not any prior knowledge or information found on the internet.
-# Context:
-The following case extracts have been found in either Swiss Federal Court or European Court of Human Rights cases and could fit the question:
-{context}
-# Task:
-If the retrieved context is not relevant cases or the issue has not been addressed within the context, just say "I can't find enough relevant information".
-Don't make up an answer or give irrelevant information not requested by the user.
-Otherwise, if relevant cases were found, answer in the user's question's language using the context that you found relevant and reference the sources, including the urls and dates.
-# Instructions:
-Always answer the user using the language used in his question: {message}
-"""
-
-    print(system_message)
-    messages = [{"role": "system", "content": system_message}]
+    print(datetime.now())
+    context = retriever_tool(message)
+
+    prompt = f"""Given the question and supporting documents below, give a comprehensive answer to the question.
+Respond only to the question asked, response should be concise and relevant to the question and answer in the same language as the question.
+Provide the number of the source document when relevant, as well as the link to the document.
+If you cannot find information, do not give up and try calling your retriever again with different arguments!
+
+Question:
+{message}
+
+{context}
+"""
+
+    messages = [{"role": "user", "content": prompt}]
+
+    for val in history:
+        if val[0]:
+            messages.append({"role": "user", "content": val[0]})
+        if val[1]:
+            messages.append({"role": "assistant", "content": val[1]})
 
     messages.append({"role": "user", "content": message})
 
@@ -129,6 +116,8 @@ Always answer the user using the language used in his question: {message}
         top_p=top_p,
     ):
         token = message.choices[0].delta.content
+
+        # answer = client.chat_completion(messages, temperature=0.1).choices[0].message.content
 
         response += token
         yield response
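For reference, a minimal sketch of how the new retriever tool could be exercised on its own, outside the Gradio app. This is not part of the commit: it assumes the FAISS index and case CSV have already been fetched (e.g. via the snapshot_download call in app.py), that the module is importable as app, and the query string is a made-up example.

# Minimal sketch, not part of the commit: run the new RetrieverTool standalone.
# Assumes "faiss_index_mpnet_cos" and "bger_cedh_db 1954-2024.csv" are already
# in the working directory; the query below is hypothetical.
from langchain_community.vectorstores import FAISS, DistanceStrategy
from langchain_huggingface import HuggingFaceEmbeddings

from app import RetrieverTool  # the class added in this commit

embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2")
vector_db = FAISS.load_local(
    "faiss_index_mpnet_cos",
    embeddings,
    allow_dangerous_deserialization=True,
    distance_strategy=DistanceStrategy.COSINE,
)

retriever_tool = RetrieverTool(vector_db)
# smolagents Tool instances are callable; calling one dispatches to forward(),
# which is how app.py itself invokes the tool inside respond().
print(retriever_tool("right to a fair trial in criminal proceedings"))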