Update app.py
Browse files
Translate prompt and extract longer context chunks
app.py
CHANGED
@@ -41,28 +41,53 @@ def respond(
|
|
41 |
|
42 |
print(datetime.now())
|
43 |
print(system_message)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
44 |
|
45 |
# retriever = vector_db.as_retriever(search_type="similarity_score_threshold", search_kwargs={"score_threshold": score, "k": 10})
|
46 |
-
retriever = vector_db.as_retriever(search_type="similarity", search_kwargs={"k": 10})
|
47 |
# retriever = vector_db.as_retriever(search_type="mmr")
|
48 |
-
documents = retriever.invoke(message)
|
|
|
|
|
|
|
|
|
|
|
49 |
|
|
|
|
|
|
|
|
|
50 |
spacer = " \n"
|
51 |
context = ""
|
|
|
52 |
|
53 |
#print(message)
|
54 |
-
print(len(documents))
|
55 |
|
56 |
for doc in documents:
|
57 |
-
|
58 |
-
|
|
|
|
|
|
|
|
|
59 |
context += "#######" + spacer
|
60 |
-
context += "# Case number: " + doc.metadata["case_nb"] + spacer
|
61 |
-
context += "# Case source: " + ("Swiss Federal Court" if doc.metadata["case_ref"] == "ATF" else "European Court of Human Rights") + spacer
|
62 |
-
context += "# Case date: " + doc.metadata["case_date"] + spacer
|
63 |
-
context += "# Case url: " + doc.metadata["case_url"] + spacer
|
64 |
-
context += "# Case text: " + doc.page_content + spacer
|
65 |
-
|
66 |
|
67 |
#print("# Case number: " + doc.metadata["case_nb"] + spacer)
|
68 |
#print("# Case url: " + doc.metadata["case_url"] + spacer)
|
@@ -114,7 +139,7 @@ demo = gr.ChatInterface(
|
|
114 |
additional_inputs=[
|
115 |
gr.Textbox(value="You are an assistant in Swiss Jurisprudence cases.", label="System message"),
|
116 |
gr.Slider(minimum=1, maximum=24000, value=5000, step=1, label="Max new tokens"),
|
117 |
-
gr.Slider(minimum=0.1, maximum=4.0, value=0, step=0.1, label="Temperature"),
|
118 |
gr.Slider(
|
119 |
minimum=0.1,
|
120 |
maximum=1.0,
|
|
|
41 |
|
42 |
print(datetime.now())
|
43 |
print(system_message)
|
44 |
+
|
45 |
+
prompt_en = client.text_generation("Improve or translate the following user's prompt to English giving only the new prompt without explanations or additional text: " + message)
|
46 |
+
prompt_de = client.text_generation("Improve or translate the following user's prompt to German giving only the new prompt without explanations or additional text: " + message)
|
47 |
+
prompt_fr = client.text_generation("Improve or translate the following user's prompt to French giving only the new prompt without explanations or additional text: " + message)
|
48 |
+
prompt_it = client.text_generation("Improve or translate the following user's prompt to Italian giving only the new prompt without explanations or additional text: " + message)
|
49 |
+
|
50 |
+
print(prompt_en)
|
51 |
+
print(prompt_de)
|
52 |
+
print(prompt_fr)
|
53 |
+
print(prompt_it)
|
54 |
+
|
55 |
|
56 |
# retriever = vector_db.as_retriever(search_type="similarity_score_threshold", search_kwargs={"score_threshold": score, "k": 10})
|
57 |
+
# retriever = vector_db.as_retriever(search_type="similarity", search_kwargs={"k": 10})
|
58 |
# retriever = vector_db.as_retriever(search_type="mmr")
|
59 |
+
# documents = retriever.invoke(message)
|
60 |
+
|
61 |
+
documents_en = vector_db.similarity_search_with_score(prompt_en, k=4)
|
62 |
+
documents_de = vector_db.similarity_search_with_score(prompt_de, k=4)
|
63 |
+
documents_fr = vector_db.similarity_search_with_score(prompt_fr, k=4)
|
64 |
+
documents_it = vector_db.similarity_search_with_score(prompt_it, k=4)
|
65 |
|
66 |
+
documents = documents_en + documents_de + documents_fr + documents_it
|
67 |
+
|
68 |
+
documents = sorted(documents, key=lambda x: x[1])[:4]
|
69 |
+
|
70 |
spacer = " \n"
|
71 |
context = ""
|
72 |
+
nb_char = 2000
|
73 |
|
74 |
#print(message)
|
75 |
+
print(f"* Documents found: {len(documents)}")
|
76 |
|
77 |
for doc in documents:
|
78 |
+
case_text = df[df["case_url"] == doc[0].metadata["case_url"]].case_text.values[0]
|
79 |
+
index = case_text.find(doc[0].page_content)
|
80 |
+
start = max(0, index - nb_char)
|
81 |
+
end = min(len(case_text), index + len(doc[0].page_content) + nb_char)
|
82 |
+
case_text_summary = case_text[start:end]
|
83 |
+
|
84 |
context += "#######" + spacer
|
85 |
+
context += "# Case number: " + doc[0].metadata["case_nb"] + spacer
|
86 |
+
context += "# Case source: " + ("Swiss Federal Court" if doc[0].metadata["case_ref"] == "ATF" else "European Court of Human Rights") + spacer
|
87 |
+
context += "# Case date: " + doc[0].metadata["case_date"] + spacer
|
88 |
+
context += "# Case url: " + doc[0].metadata["case_url"] + spacer
|
89 |
+
#context += "# Case text: " + doc[0].page_content + spacer
|
90 |
+
context += "Case extract: " + case_text_summary + spacer
|
91 |
|
92 |
#print("# Case number: " + doc.metadata["case_nb"] + spacer)
|
93 |
#print("# Case url: " + doc.metadata["case_url"] + spacer)
|
|
|
139 |
additional_inputs=[
|
140 |
gr.Textbox(value="You are an assistant in Swiss Jurisprudence cases.", label="System message"),
|
141 |
gr.Slider(minimum=1, maximum=24000, value=5000, step=1, label="Max new tokens"),
|
142 |
+
gr.Slider(minimum=0.1, maximum=4.0, value=0.1, step=0.1, label="Temperature"),
|
143 |
gr.Slider(
|
144 |
minimum=0.1,
|
145 |
maximum=1.0,
|