File size: 6,501 Bytes
5c60ed2
0bf6060
f8adcff
1080fd6
0635997
7c12ef4
d5c54ef
037376c
0635997
 
f846748
 
 
35f9142
 
60ac7f7
 
213b4a3
 
60ac7f7
 
f846748
0635997
 
79c456d
0635997
11fc4a7
0635997
8661441
d5c54ef
f846748
 
 
 
 
 
 
c3ef985
f846748
2e4ad5e
e6d12c5
037376c
e6d12c5
a6051b9
a1e734a
 
 
 
 
 
 
a6051b9
c3d6f33
a6051b9
47761f1
a6051b9
 
 
f7848c9
 
a6051b9
f7848c9
a6051b9
f7848c9
 
a6051b9
f7848c9
443f706
a6051b9
 
 
 
d2eb5fb
309b510
a6051b9
5a4546c
ae1f860
a6051b9
309b510
7cc0278
a6051b9
 
 
 
 
 
d3c72ad
a6051b9
 
 
 
 
 
5a4546c
ae1f860
 
ee3e1d1
1b8a611
6045274
 
 
 
a0ff377
e6d12c5
 
1ab3741
a0ff377
6045274
 
ae1f860
2e4ad5e
 
4e67249
34f414b
 
 
 
 
 
a2933d7
f846748
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b7d6ba3
d2eb5fb
a6051b9
f846748
 
 
 
 
 
 
4c92796
f846748
cddcba8
f846748
 
 
 
cdec1a7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
import gradio as gr
from transformers import pipeline
from huggingface_hub import InferenceClient, login, snapshot_download
from langchain_community.vectorstores import FAISS, DistanceStrategy
from langchain_huggingface import HuggingFaceEmbeddings
import os
import pandas as pd
from datetime import datetime


"""
For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
"""
# Hugging Face access token, read from the `TOKEN` environment variable.
HF_TOKEN = os.getenv('TOKEN')
if HF_TOKEN:
    login(HF_TOKEN)
else:
    # login() without a token falls back to an interactive prompt, which cannot
    # be answered in a headless deployment; continue with any cached credentials.
    print("Warning: TOKEN environment variable is not set; skipping Hugging Face login.")

# Chat model queried through the Inference API.
# Other models tried previously: "meta-llama/Llama-3.2-1B-Instruct", "google/mt5-small".
model = "mistralai/Mistral-7B-Instruct-v0.3"

client = InferenceClient(model)

# Download the dataset repository into the working directory. The returned
# path is used to locate the index and CSV so they are resolved against the
# download location explicitly instead of implicitly via the current directory.
folder = snapshot_download(repo_id="umaiku/faiss_index", repo_type="dataset", local_dir=os.getcwd())

# Multilingual sentence-embedding model used to embed the search queries.
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2")

# SECURITY NOTE: allow_dangerous_deserialization=True unpickles the stored
# docstore. This is only acceptable because the index comes from the project's
# own dataset repository; never point this at untrusted data.
vector_db = FAISS.load_local(
    os.path.join(folder, "faiss_index_mpnet_cos"),
    embeddings,
    allow_dangerous_deserialization=True,
    distance_strategy=DistanceStrategy.COSINE,
)

# Full decision texts, looked up by "case_url" to widen retrieved chunks.
df = pd.read_csv(os.path.join(folder, "faiss_index", "bger_cedh_db 1954-2024.csv"))

# Languages into which the user's question is rewritten; the index is searched
# once per rewritten query and all hits are pooled before ranking.
_QUERY_LANGUAGES = ("English", "German", "French", "Italian")

# Instruction prepended to the user's message to obtain one rewritten query.
# NOTE: the run of spaces after "new prompt" is part of the prompt text the
# model has been receiving; it is left unchanged deliberately.
_PROMPT_TEMPLATE = (
    "Improve or translate the following user's prompt to {language} giving only the new prompt"
    "    without explanations or additional text and if you can't improve it, "
    "just return the same prompt, do not extrapolate: "
)

_DOCS_PER_QUERY = 4  # hits requested from the index for each rewritten query
_MAX_DOCS = 4  # distinct chunks kept for the LLM context
_CONTEXT_CHARS = 2000  # characters of surrounding case text added on each side
_SPACER = " \n"  # line separator used inside the context block


def _merge_results(results, k: int):
    """Return the ``k`` best ``(document, score)`` pairs with repeats removed.

    Pairs are ranked by score in ascending order, i.e. lower scores are treated
    as better. NOTE(review): confirm this matches the distance semantics of the
    FAISS index. The same chunk is often returned for several language variants
    of the query; only its best-ranked occurrence is kept.
    """
    ranked = []
    seen = set()
    for doc, score in sorted(results, key=lambda pair: pair[1]):
        if len(ranked) >= k:
            break
        key = (doc.metadata.get("case_url"), doc.page_content)
        if key in seen:
            continue
        seen.add(key)
        ranked.append((doc, score))
    return ranked


def _case_extract(case_text, snippet: str, nb_char: int) -> str:
    """Return ``snippet`` widened by up to ``nb_char`` characters on each side.

    The snippet is located inside the full ``case_text``. If the full text is
    unavailable (not a string, e.g. a missing CSV value) or the snippet cannot
    be found in it, the snippet itself is returned rather than an unrelated
    slice from the start of the decision.
    """
    if not isinstance(case_text, str):
        return snippet
    index = case_text.find(snippet)
    if index == -1:
        return snippet
    start = max(0, index - nb_char)
    end = min(len(case_text), index + len(snippet) + nb_char)
    return case_text[start:end]


def _full_case_text(case_url):
    """Return the ``case_text`` of the first row of ``df`` matching ``case_url``, or None."""
    matches = df.loc[df["case_url"] == case_url, "case_text"]
    return matches.iloc[0] if len(matches) else None


def _build_context(documents, nb_char: int = _CONTEXT_CHARS) -> str:
    """Format retrieved chunks as the context section of the system prompt.

    Each entry gives the case number, court, date and URL from the chunk's
    metadata, followed by the chunk widened with surrounding text from the
    full decision.
    """
    entries = []
    for doc, _score in documents:
        meta = doc.metadata
        court = (
            "Swiss Federal Court"
            if meta["case_ref"] == "ATF"
            else "European Court of Human Rights"
        )
        extract = _case_extract(_full_case_text(meta["case_url"]), doc.page_content, nb_char)
        entries.append(
            _SPACER.join(
                [
                    "#######",
                    f"# Case number: {meta['case_nb']}",
                    f"# Case source: {court}",
                    f"# Case date: {meta['case_date']}",
                    f"# Case url: {meta['case_url']}",
                    f"Case extract: {extract}",
                ]
            )
            + _SPACER
        )
    return "".join(entries)


def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
    score,
):
    """Answer ``message`` using cases retrieved from the vector index.

    The question is rewritten into each language in ``_QUERY_LANGUAGES`` with
    the LLM, and every rewrite is searched in ``vector_db``. The best distinct
    hits are formatted into a context appended to ``system_message``, and the
    chat model's answer is streamed back.

    Args:
        message: the user's question.
        history: previous chat turns from ``gr.ChatInterface``; currently
            ignored, so each question is answered independently.
        system_message: base system prompt; retrieval context and task
            instructions are appended to it.
        max_tokens: forwarded to ``client.chat_completion``.
        temperature: forwarded to ``client.chat_completion``.
        top_p: forwarded to ``client.chat_completion``.
        score: value of the "Score Threshold" slider. NOTE(review): not used;
            the best ``_MAX_DOCS`` hits are kept regardless of their score.

    Yields:
        The answer accumulated so far, once per streamed chunk.
    """
    print(datetime.now())
    print(system_message)

    results = []
    for language in _QUERY_LANGUAGES:
        rewritten = client.text_generation(_PROMPT_TEMPLATE.format(language=language) + message)
        results.extend(vector_db.similarity_search_with_score(rewritten, k=_DOCS_PER_QUERY))
        print(rewritten)

    documents = _merge_results(results, _MAX_DOCS)
    print(f"* Documents found: {len(documents)}")
    context = _build_context(documents)

    system_message += f"""A user is asking you the following question: {message}
Please answer the user in the same language that he used in his question using ONLY the following given context not any prior knowledge or information found on the internet.
# Context:
The following case extracts have been found in either Swiss Federal Court or European Court of Human Rights cases and could fit the question:
{context}
# Task:
If the retrieved context is not relevant cases or the issue has not been addressed within the context, just say "I can't find enough relevant information".
Don't make up an answer or give irrelevant information not requested by the user.
Otherwise, if relevant cases were found, answer in the user's question's language using the context that you found relevant and reference the sources, including the urls and dates.
# Instructions:
Always answer the user using the language used in his question: {message}
"""

    print(system_message)
    messages = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": message},
    ]

    response = ""
    for chunk in client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        if not chunk.choices:
            continue
        # delta.content is optional and may be None (e.g. on a terminal chunk).
        response += chunk.choices[0].delta.content or ""
        yield response


"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""
# The additional inputs are passed to `respond` in order after (message, history):
# system_message, max_tokens, temperature, top_p, score.
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(value="You are assisting a jurist or a lawyer in finding relevant Swiss Jurisprudence cases to their question.", label="System message"),
        gr.Slider(minimum=1, maximum=24000, value=5000, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.1, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
        # NOTE(review): `respond` receives this value as `score` but does not
        # use it, so the slider currently has no effect on retrieval.
        gr.Slider(minimum=0, maximum=1, value=0.75, step=0.05, label="Score Threshold"),
    ],
    description="# 📜 ALexI: Artificial Legal Intelligence for Swiss Jurisprudence",
)


if __name__ == "__main__":
    demo.launch(debug=True)