umaiku commited on
Commit
a6051b9
·
verified ·
1 Parent(s): 79c456d

Update app.py

Browse files

Translate prompt and extract longer context chunks

Files changed (1) hide show
  1. app.py +37 -12
app.py CHANGED
@@ -41,28 +41,53 @@ def respond(
41
 
42
  print(datetime.now())
43
  print(system_message)
 
 
 
 
 
 
 
 
 
 
 
44
 
45
  # retriever = vector_db.as_retriever(search_type="similarity_score_threshold", search_kwargs={"score_threshold": score, "k": 10})
46
- retriever = vector_db.as_retriever(search_type="similarity", search_kwargs={"k": 10})
47
  # retriever = vector_db.as_retriever(search_type="mmr")
48
- documents = retriever.invoke(message)
 
 
 
 
 
49
 
 
 
 
 
50
  spacer = " \n"
51
  context = ""
 
52
 
53
  #print(message)
54
- print(len(documents))
55
 
56
  for doc in documents:
57
- #case_text = df[df["case_url"] == doc.metadata["case_url"]].case_text.values[0]
58
-
 
 
 
 
59
  context += "#######" + spacer
60
- context += "# Case number: " + doc.metadata["case_nb"] + spacer
61
- context += "# Case source: " + ("Swiss Federal Court" if doc.metadata["case_ref"] == "ATF" else "European Court of Human Rights") + spacer
62
- context += "# Case date: " + doc.metadata["case_date"] + spacer
63
- context += "# Case url: " + doc.metadata["case_url"] + spacer
64
- context += "# Case text: " + doc.page_content + spacer
65
- #context += "Case text: " + case_text[:8000] + spacer
66
 
67
  #print("# Case number: " + doc.metadata["case_nb"] + spacer)
68
  #print("# Case url: " + doc.metadata["case_url"] + spacer)
@@ -114,7 +139,7 @@ demo = gr.ChatInterface(
114
  additional_inputs=[
115
  gr.Textbox(value="You are an assistant in Swiss Jurisprudence cases.", label="System message"),
116
  gr.Slider(minimum=1, maximum=24000, value=5000, step=1, label="Max new tokens"),
117
- gr.Slider(minimum=0.1, maximum=4.0, value=0, step=0.1, label="Temperature"),
118
  gr.Slider(
119
  minimum=0.1,
120
  maximum=1.0,
 
41
 
42
  print(datetime.now())
43
  print(system_message)
44
+
45
+ prompt_en = client.text_generation("Improve or translate the following user's prompt to English giving only the new prompt without explanations or additional text: " + message)
46
+ prompt_de = client.text_generation("Improve or translate the following user's prompt to German giving only the new prompt without explanations or additional text: " + message)
47
+ prompt_fr = client.text_generation("Improve or translate the following user's prompt to French giving only the new prompt without explanations or additional text: " + message)
48
+ prompt_it = client.text_generation("Improve or translate the following user's prompt to Italian giving only the new prompt without explanations or additional text: " + message)
49
+
50
+ print(prompt_en)
51
+ print(prompt_de)
52
+ print(prompt_fr)
53
+ print(prompt_it)
54
+
55
 
56
  # retriever = vector_db.as_retriever(search_type="similarity_score_threshold", search_kwargs={"score_threshold": score, "k": 10})
57
+ # retriever = vector_db.as_retriever(search_type="similarity", search_kwargs={"k": 10})
58
  # retriever = vector_db.as_retriever(search_type="mmr")
59
+ # documents = retriever.invoke(message)
60
+
61
+ documents_en = vector_db.similarity_search_with_score(prompt_en, k=4)
62
+ documents_de = vector_db.similarity_search_with_score(prompt_de, k=4)
63
+ documents_fr = vector_db.similarity_search_with_score(prompt_fr, k=4)
64
+ documents_it = vector_db.similarity_search_with_score(prompt_it, k=4)
65
 
66
+ documents = documents_en + documents_de + documents_fr + documents_it
67
+
68
+ documents = sorted(documents, key=lambda x: x[1])[:4]
69
+
70
  spacer = " \n"
71
  context = ""
72
+ nb_char = 2000
73
 
74
  #print(message)
75
+ print(f"* Documents found: {len(documents)}")
76
 
77
  for doc in documents:
78
+ case_text = df[df["case_url"] == doc[0].metadata["case_url"]].case_text.values[0]
79
+ index = case_text.find(doc[0].page_content)
80
+ start = max(0, index - nb_char)
81
+ end = min(len(case_text), index + len(doc[0].page_content) + nb_char)
82
+ case_text_summary = case_text[start:end]
83
+
84
  context += "#######" + spacer
85
+ context += "# Case number: " + doc[0].metadata["case_nb"] + spacer
86
+ context += "# Case source: " + ("Swiss Federal Court" if doc[0].metadata["case_ref"] == "ATF" else "European Court of Human Rights") + spacer
87
+ context += "# Case date: " + doc[0].metadata["case_date"] + spacer
88
+ context += "# Case url: " + doc[0].metadata["case_url"] + spacer
89
+ #context += "# Case text: " + doc[0].page_content + spacer
90
+ context += "Case extract: " + case_text_summary + spacer
91
 
92
  #print("# Case number: " + doc.metadata["case_nb"] + spacer)
93
  #print("# Case url: " + doc.metadata["case_url"] + spacer)
 
139
  additional_inputs=[
140
  gr.Textbox(value="You are an assistant in Swiss Jurisprudence cases.", label="System message"),
141
  gr.Slider(minimum=1, maximum=24000, value=5000, step=1, label="Max new tokens"),
142
+ gr.Slider(minimum=0.1, maximum=4.0, value=0.1, step=0.1, label="Temperature"),
143
  gr.Slider(
144
  minimum=0.1,
145
  maximum=1.0,