umaiku committed
Commit 90d1e52 · verified · 1 Parent(s): 1b8a611

Update app.py

Files changed (1)
  1. app.py +76 -87
app.py CHANGED
@@ -1,12 +1,63 @@
  import gradio as gr
  from transformers import pipeline
  from huggingface_hub import InferenceClient, login, snapshot_download
- from langchain_community.vectorstores import FAISS
+ from langchain_community.vectorstores import FAISS
+ from langchain_community.vectorstores.utils import DistanceStrategy  # DistanceStrategy is exported from vectorstores.utils, not from the package root
  from langchain_huggingface import HuggingFaceEmbeddings
  import os
  import pandas as pd
  from datetime import datetime

+ from smolagents import Tool, HfApiModel, ToolCallingAgent
+ from langchain_core.vectorstores import VectorStore
+
+
+ class RetrieverTool(Tool):
+     name = "retriever"
+     description = "Using semantic similarity in German, French, English and Italian, retrieves some documents from the knowledge base that have the closest embeddings to the input query."
+     inputs = {
+         "query": {
+             "type": "string",
+             "description": "The query to perform. This should be semantically close to your target documents. Use the affirmative form rather than a question.",
+         }
+     }
+     output_type = "string"
+
+     def __init__(self, vectordb: VectorStore, **kwargs):
+         super().__init__(**kwargs)
+         self.vectordb = vectordb
+
+     def forward(self, query: str) -> str:
+         assert isinstance(query, str), "Your search query must be a string"
+
+         docs = self.vectordb.similarity_search(
+             query,
+             k=7,
+         )
+
+         # same CSV as the module-level df; the file lives inside the downloaded snapshot folder
+         df = pd.read_csv("faiss_index/bger_cedh_db 1954-2024.csv")
+
+         spacer = " \n"
+         context = ""
+         nb_char = 100
+
+         for doc in docs:
+             # locate the retrieved chunk in the full decision and keep nb_char of text on each side
+             case_text = df[df["case_url"] == doc.metadata["case_url"]].case_text.values[0]
+             index = case_text.find(doc.page_content)
+             start = max(0, index - nb_char)
+             end = min(len(case_text), index + len(doc.page_content) + nb_char)
+             case_text_summary = case_text[start:end]
+
+             context += "#######" + spacer
+             context += "# Case number: " + doc.metadata["case_ref"] + " " + doc.metadata["case_nb"] + spacer
+             context += "# Case source: " + ("Swiss Federal Court" if doc.metadata["case_ref"] == "ATF" else "European Court of Human Rights") + spacer
+             context += "# Case date: " + doc.metadata["case_date"] + spacer
+             context += "# Case url: " + doc.metadata["case_url"] + spacer
+             #context += "# Case text: " + doc.page_content + spacer
+             context += "# Case extract: " + case_text_summary + spacer
+
+         return "\nRetrieved documents:\n" + context
+

  """
  For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
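
The new RetrieverTool can be smoke-tested on its own before it is handed to the agent. A minimal sketch, assuming the FAISS index folder and the CSV from the umaiku/faiss_index dataset snapshot are already on disk (as app.py arranges via snapshot_download); the query string is only illustrative:

# Standalone check of RetrieverTool, reusing the same embeddings and index as app.py.
from langchain_community.vectorstores import FAISS
from langchain_community.vectorstores.utils import DistanceStrategy
from langchain_huggingface import HuggingFaceEmbeddings

embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2")
vector_db = FAISS.load_local(
    "faiss_index_mpnet_cos",
    embeddings,
    allow_dangerous_deserialization=True,
    distance_strategy=DistanceStrategy.COSINE,
)

tool = RetrieverTool(vector_db)
# smolagents Tools are callable: __call__ validates the inputs schema, then dispatches to forward()
print(tool("Expulsion of a foreign national despite family ties in Switzerland"))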
@@ -14,9 +65,7 @@ For more information on `huggingface_hub` Inference API support, please check th
  HF_TOKEN=os.getenv('TOKEN')
  login(HF_TOKEN)

- #model = "meta-llama/Llama-3.2-1B-Instruct"
- #model = "google/mt5-small"
- model = "mistralai/Mistral-7B-Instruct-v0.3"
+ model = "meta-llama/Meta-Llama-3-8B-Instruct"

  client = InferenceClient(model)
 
@@ -24,98 +73,36 @@ folder = snapshot_download(repo_id="umaiku/faiss_index", repo_type="dataset", lo

  embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2")

- vector_db = FAISS.load_local("faiss_index_mpnet_cos", embeddings, allow_dangerous_deserialization=True)
+ vector_db = FAISS.load_local("faiss_index_mpnet_cos", embeddings, allow_dangerous_deserialization=True, distance_strategy=DistanceStrategy.COSINE)

  df = pd.read_csv("faiss_index/bger_cedh_db 1954-2024.csv")

- def respond(
-     message,
-     history: list[tuple[str, str]],
-     system_message,
-     max_tokens,
-     temperature,
-     top_p,
-     score,
- ):
-     #messages = [{"role": "system", "content": system_message}]
-
-     print(datetime.now())
-     print(system_message)
-
-     prompt_template = "Improve or translate the following user's prompt to {language} giving only the new prompt\
- without explanations or additional text and if you can't improve it, just return the same prompt, do not extrapolate: "
-
-     prompt_en = client.text_generation(prompt_template.format(language="English") + message)
-     prompt_de = client.text_generation(prompt_template.format(language="German") + message)
-     prompt_fr = client.text_generation(prompt_template.format(language="French") + message)
-     prompt_it = client.text_generation(prompt_template.format(language="Italian") + message)
-
-     # retriever = vector_db.as_retriever(search_type="similarity_score_threshold", search_kwargs={"score_threshold": score, "k": 10})
-     # retriever = vector_db.as_retriever(search_type="similarity", search_kwargs={"k": 10})
-     # retriever = vector_db.as_retriever(search_type="mmr")
-     # documents = retriever.invoke(message)
-
-     documents_en = vector_db.similarity_search_with_score(prompt_en, k=4)
-     print(prompt_en)
-     documents_de = vector_db.similarity_search_with_score(prompt_de, k=4)
-     print(prompt_de)
-     documents_fr = vector_db.similarity_search_with_score(prompt_fr, k=4)
-     print(prompt_fr)
-     documents_it = vector_db.similarity_search_with_score(prompt_it, k=4)
-     print(prompt_it)
-
-     documents = documents_en + documents_de + documents_fr + documents_it
-     documents = sorted(documents, key=lambda x: x[1])[:4]
-
-     spacer = " \n"
-     context = ""
-     nb_char = 2000
-
-     #print(message)
-     print(f"* Documents found: {len(documents)}")
-
-     for doc in documents:
-         case_text = df[df["case_url"] == doc[0].metadata["case_url"]].case_text.values[0]
-         index = case_text.find(doc[0].page_content)
-         start = max(0, index - nb_char)
-         end = min(len(case_text), index + len(doc[0].page_content) + nb_char)
-         case_text_summary = case_text[start:end]
-
-         context += "#######" + spacer
-         context += "# Case number: " + doc[0].metadata["case_nb"] + spacer
-         context += "# Case source: " + ("Swiss Federal Court" if doc[0].metadata["case_ref"] == "ATF" else "European Court of Human Rights") + spacer
-         context += "# Case date: " + doc[0].metadata["case_date"] + spacer
-         context += "# Case url: " + doc[0].metadata["case_url"] + spacer
-         #context += "# Case text: " + doc[0].page_content + spacer
-         context += "Case extract: " + case_text_summary + spacer
-
-         #print("# Case number: " + doc.metadata["case_nb"] + spacer)
-         #print("# Case url: " + doc.metadata["case_url"] + spacer)
-
-     system_message += f"""A user is asking you the following question: {message}
- Please answer the user in the same language that he used in his question using ONLY the following given context not any prior knowledge or information found on the internet.
- # Context:
- The following case extracts have been found in either Swiss Federal Court or European Court of Human Rights cases and could fit the question:
- {context}
- # Task:
- If the retrieved context is not relevant cases or the issue has not been addressed within the context, just say "I can't find enough relevant information".
- Don't make up an answer or give irrelevant information not requested by the user.
- Otherwise, if relevant cases were found, answer in the user's question's language using the context that you found relevant and reference the sources, including the urls and dates.
- # Instructions:
- Always answer the user using the language used in his question: {message}
- """
-
-     print(system_message)
-     messages = [{"role": "system", "content": system_message}]
-
-     # for val in history:
-     #     if val[0]:
-     #         messages.append({"role": "user", "content": val[0]})
-     #     if val[1]:
-     #         messages.append({"role": "assistant", "content": val[1]})
+ retriever_tool = RetrieverTool(vector_db)
+ agent = ToolCallingAgent(tools=[retriever_tool], model=HfApiModel(model))
+
+ def respond(message, history: list[tuple[str, str]], system_message, max_tokens, temperature, top_p, score):
+
+     print(datetime.now())
+     context = retriever_tool(message)  # the function parameter is `message`; the committed `question` is undefined here
+
+     prompt = f"""Given the question and supporting documents below, give a comprehensive answer to the question.
+ Respond only to the question asked; the response should be concise, relevant to the question, and in the same language as the question.
+ Provide the number of the source document when relevant, as well as the link to the document.
+ If you cannot find information, do not give up and try calling your retriever again with different arguments!
+
+ Question:
+ {message}
+
+ {context}
+ """
+
+     messages = [{"role": "user", "content": prompt}]
+
+     for val in history:
+         if val[0]:
+             messages.append({"role": "user", "content": val[0]})
+         if val[1]:
+             messages.append({"role": "assistant", "content": val[1]})

      messages.append({"role": "user", "content": message})

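One thing the diff leaves open: `agent` is built here but never invoked; `respond` calls `retriever_tool` directly, even though the prompt tells the model to "try calling your retriever again with different arguments". If the intent is to let the model iterate on retrieval, the single tool call inside `respond` could be routed through the agent instead. A hypothetical sketch of that variant (not what this commit does):

# Hypothetical: let the ToolCallingAgent decide when and how often to query the retriever.
answer = agent.run(
    "Using the retriever tool, answer in the language of the question and cite case numbers, urls and dates. Question: "
    + message
)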
 
@@ -129,6 +116,8 @@ Always answer the user using the language used in his question: {message}
          top_p=top_p,
      ):
          token = message.choices[0].delta.content
+
+         # answer = client.chat_completion(messages, temperature=0.1).choices[0].message.content

          response += token
          yield response
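
These context lines are the tail of the streaming loop; the call itself (new lines 109-115) sits outside the hunk. For reference, the usual huggingface_hub streaming pattern looks like the sketch below. This is the generic shape, not the hidden lines of app.py, and the None-check is a defensive addition since the final stream chunk can carry empty content:

# Generic InferenceClient streaming pattern (illustrative, mirrors respond()).
def stream_answer(client, messages, max_tokens, temperature, top_p):
    response = ""
    for chunk in client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        token = chunk.choices[0].delta.content
        if token:  # guard: the last chunk may have None content
            response += token
            yield response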
 