import os
from datetime import datetime

import gradio as gr
import pandas as pd
from huggingface_hub import InferenceClient, login, snapshot_download
from langchain_community.vectorstores import FAISS, DistanceStrategy
from langchain_huggingface import HuggingFaceEmbeddings
"""
For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
"""
HF_TOKEN = os.getenv("TOKEN")
login(HF_TOKEN)
#model = "meta-llama/Llama-3.2-1B-Instruct"
#model = "google/mt5-small"
model = "mistralai/Mistral-7B-Instruct-v0.3"
client = InferenceClient(model)
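# Pull the prebuilt FAISS index and the case database from the dataset repo.
# The embedding model must match the one used to build the index
# (multilingual MPNet, cosine distance).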
folder = snapshot_download(repo_id="umaiku/faiss_index", repo_type="dataset", local_dir=os.getcwd())
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2")
vector_db = FAISS.load_local("faiss_index_mpnet_cos", embeddings, allow_dangerous_deserialization=True, distance_strategy=DistanceStrategy.COSINE)
df = pd.read_csv("faiss_index/bger_cedh_db 1954-2024.csv")
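
# RAG flow: rewrite the query into the corpus languages, retrieve the closest
# chunks for each variant, keep the overall top 4, rebuild a wider extract
# around each hit, and stream a context-grounded answer back to the user.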
def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
    score,
):
    # messages = [{"role": "system", "content": system_message}]
    print(datetime.now())
    print(system_message)
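    # Query expansion: the corpus mixes English, German, French, and Italian,
    # so the user's question is rewritten into each language before retrieval.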
    prompt_template = (
        "Improve or translate the following user's prompt to {language}, giving only the new prompt "
        "without explanations or additional text. If you can't improve it, just return the same "
        "prompt; do not extrapolate: "
    )
    prompt_en = client.text_generation(prompt_template.format(language="English") + message)
    prompt_de = client.text_generation(prompt_template.format(language="German") + message)
    prompt_fr = client.text_generation(prompt_template.format(language="French") + message)
    prompt_it = client.text_generation(prompt_template.format(language="Italian") + message)
    # retriever = vector_db.as_retriever(search_type="similarity_score_threshold", search_kwargs={"score_threshold": score, "k": 10})
    # retriever = vector_db.as_retriever(search_type="similarity", search_kwargs={"k": 10})
    # retriever = vector_db.as_retriever(search_type="mmr")
    # documents = retriever.invoke(message)
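    # Retrieve the 4 nearest chunks per language variant; with cosine distance,
    # lower scores are better, so the merged list is sorted ascending.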
    documents_en = vector_db.similarity_search_with_score(prompt_en, k=4)
    print(prompt_en)
    documents_de = vector_db.similarity_search_with_score(prompt_de, k=4)
    print(prompt_de)
    documents_fr = vector_db.similarity_search_with_score(prompt_fr, k=4)
    print(prompt_fr)
    documents_it = vector_db.similarity_search_with_score(prompt_it, k=4)
    print(prompt_it)

    documents = documents_en + documents_de + documents_fr + documents_it
    documents = sorted(documents, key=lambda x: x[1])[:4]
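    # Rebuild a wider extract around each retrieved chunk: locate the chunk in
    # the full decision text and keep nb_char characters on each side.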
    spacer = " \n"
    context = ""
    nb_char = 2000
    # print(message)
    print(f"* Documents found: {len(documents)}")
    for doc, _distance in documents:
        case_text = df[df["case_url"] == doc.metadata["case_url"]].case_text.values[0]
        index = case_text.find(doc.page_content)
        if index == -1:
            # Chunk not found verbatim in the full text; fall back to the start.
            index = 0
        start = max(0, index - nb_char)
        end = min(len(case_text), index + len(doc.page_content) + nb_char)
        case_text_summary = case_text[start:end]
        context += "#######" + spacer
        context += "# Case number: " + doc.metadata["case_nb"] + spacer
        context += "# Case source: " + ("Swiss Federal Court" if doc.metadata["case_ref"] == "ATF" else "European Court of Human Rights") + spacer
        context += "# Case date: " + doc.metadata["case_date"] + spacer
        context += "# Case url: " + doc.metadata["case_url"] + spacer
        # context += "# Case text: " + doc.page_content + spacer
        context += "Case extract: " + case_text_summary + spacer
        # print("# Case number: " + doc.metadata["case_nb"] + spacer)
        # print("# Case url: " + doc.metadata["case_url"] + spacer)
    system_message += f"""A user is asking you the following question: {message}
Please answer the user in the same language they used in their question, using ONLY the given context below, not any prior knowledge or information found on the internet.
# Context:
The following case extracts were found in Swiss Federal Court or European Court of Human Rights cases and could fit the question:
{context}
# Task:
If the retrieved context does not contain relevant cases, or the issue has not been addressed within the context, just say "I can't find enough relevant information".
Don't make up an answer or give irrelevant information not requested by the user.
Otherwise, if relevant cases were found, answer in the language of the user's question using the context you found relevant, and reference the sources, including the URLs and dates.
# Instructions:
Always answer the user in the language used in their question: {message}
"""
    print(system_message)
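    # Chat history is currently not forwarded to the model (the history loop
    # below is commented out); only the system prompt and the latest user
    # message are sent.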
    messages = [{"role": "system", "content": system_message}]
    # for val in history:
    #     if val[0]:
    #         messages.append({"role": "user", "content": val[0]})
    #     if val[1]:
    #         messages.append({"role": "assistant", "content": val[1]})
    messages.append({"role": "user", "content": message})
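    # Stream the completion chunk by chunk so the UI updates progressively.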
    response = ""
    for chunk in client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        token = chunk.choices[0].delta.content
        if token:  # the final chunk may carry no content
            response += token
        yield response
"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(
            value="You are assisting a jurist or a lawyer in finding relevant Swiss jurisprudence cases for their question.",
            label="System message",
        ),
        gr.Slider(minimum=1, maximum=24000, value=5000, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.1, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
        gr.Slider(minimum=0, maximum=1, value=0.75, step=0.05, label="Score Threshold"),
    ],
    description="# 📜 ALexI: Artificial Legal Intelligence for Swiss Jurisprudence",
)
if __name__ == "__main__":
    demo.launch(debug=True)