File size: 3,622 Bytes
5c60ed2
f8adcff
0635997
 
7c12ef4
0635997
 
f846748
 
 
cd7ca86
185d396
57b0d16
f846748
0635997
 
 
 
 
 
f846748
 
 
 
 
 
 
c3ef985
f846748
5b9e4ac
f846748
c3ef985
7e5bae2
443f706
7e5bae2
a37b742
 
 
7e5bae2
 
6334495
 
309b510
 
 
 
 
 
 
 
7e5bae2
 
 
309b510
 
 
 
7e5bae2
 
1cafcb9
a37b742
be65967
34f414b
 
 
 
 
 
a2933d7
f846748
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d1fef0d
b70a745
f846748
 
 
 
 
 
 
 
c3ef985
f846748
cddcba8
f846748
 
 
 
cdec1a7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import gradio as gr
from huggingface_hub import InferenceClient, login, snapshot_download
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
import os


"""
For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
"""
# Authenticate against the Hugging Face Hub with the value of the TOKEN
# environment variable (os.getenv returns None when the variable is unset).
login(token=os.getenv('TOKEN'))
# Inference client for the chat model; `respond` streams completions through it.
client = InferenceClient("meta-llama/Llama-3.2-1B-Instruct")
# Alternative model, currently disabled:
#client = InferenceClient("mistralai/Mistral-7B-Instruct-v0.3")

# Download the `umaiku/faiss_index` dataset repo into the current working directory.
# NOTE(review): `folder` is not used in the visible code; the index is loaded
# below through the relative path "faiss_index", presumably a directory that
# this download places in the working directory - confirm.
folder = snapshot_download(repo_id="umaiku/faiss_index", repo_type="dataset", local_dir=os.getcwd())

# Embedding model handed to FAISS.load_local.
# NOTE(review): presumably the same model the index was built with - confirm.
embeddings = HuggingFaceEmbeddings(model_name="intfloat/multilingual-e5-small")

# NOTE(review): allow_dangerous_deserialization=True opts in to loading
# pickled data shipped with the index; this is only acceptable while the
# dataset repo above is trusted - confirm.
vector_db = FAISS.load_local("faiss_index", embeddings, allow_dangerous_deserialization=True)

def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
    score,
):
    """Stream an answer to *message* built from retrieved jurisprudence chunks.

    The user message is used to query the module-level ``vector_db``; every
    returned document is formatted into a context section that is embedded in
    the prompt sent to the module-level ``client``.

    Args:
        message: The user's latest chat message.
        history: Previous chat turns. Accepted for the ChatInterface callback
            signature but currently not forwarded to the model.
        system_message: Content of the system message sent to the model.
        max_tokens: Forwarded to ``client.chat_completion`` as ``max_tokens``.
        temperature: Forwarded to ``client.chat_completion``.
        top_p: Forwarded to ``client.chat_completion``.
        score: Score threshold passed to the retriever's
            ``similarity_score_threshold`` search.

    Yields:
        The response text accumulated so far, once per streamed chunk that
        carries new content.

    Raises:
        KeyError: If a retrieved document lacks ``case_nb``, ``case_date`` or
            ``case_url`` in its metadata.
    """
    retriever = vector_db.as_retriever(
        search_type="similarity_score_threshold",
        search_kwargs={"score_threshold": score},
    )
    documents = retriever.invoke(message)

    # f-strings (rather than "+") so non-string metadata values such as
    # numeric case numbers do not raise TypeError.
    context = "".join(
        f"Case number: {doc.metadata['case_nb']}\n"
        f"Case date: {doc.metadata['case_date']}\n"
        f"Case url: {doc.metadata['case_url']}\n"
        f"Case chunk: {doc.page_content}\n"
        for doc in documents
    )

    prompt = f"""
        The user is asking for information about the following: {message}.
        Answer him in his own language using the information from the following Swiss federal jurisprudence cases:
        {context}
        Please mention your sources in your answer.
        If you don't know just mention the sources.
        
    """

    # Echo the final prompt to stdout so it shows up in the application logs.
    print(prompt)

    messages = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": prompt},
    ]

    response = ""

    for chunk in client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        # Streamed chunks may carry no choices or a delta without content
        # (e.g. role-only or final chunks); skip them instead of crashing on
        # `str += None`.
        if not chunk.choices:
            continue
        token = chunk.choices[0].delta.content
        if not token:
            continue

        response += token
        yield response


"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""
# Chat UI. The extra controls are listed in the same order as the parameters
# that `respond` declares after (message, history): system_message,
# max_tokens, temperature, top_p, score.
demo = gr.ChatInterface(
    respond,
    description="# 📜 ALexI: Artificial Legal Intelligence for Swiss Jurisprudence",
    additional_inputs=[
        gr.Textbox(
            label="System message",
            value="You are an assistant in Swiss Jurisprudence cases.",
        ),
        gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=24000,
            step=1,
            value=512,
        ),
        gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=4.0,
            step=0.1,
            value=0.7,
        ),
        gr.Slider(label="Top-p (nucleus sampling)", minimum=0.1, maximum=1.0, step=0.05, value=0.95),
        gr.Slider(
            label="Score Threshold",
            minimum=0,
            maximum=1,
            step=0.1,
            value=0.7,
        ),
    ],
)


# Launch the interface only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch(debug=True)