import os

import gradio as gr
from huggingface_hub import InferenceClient, login, snapshot_download
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
"""
For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
"""
login(token=os.getenv('TOKEN'))
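# Chat model served through the Hugging Face Inference API (a Mistral alternative is kept commented out below).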
client = InferenceClient("meta-llama/Llama-3.2-1B-Instruct")
#client = InferenceClient("mistralai/Mistral-7B-Instruct-v0.3")
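# Download the prebuilt FAISS index from the Hub dataset repo into the working directory.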
folder = snapshot_download(repo_id="umaiku/faiss_index", repo_type="dataset", local_dir=os.getcwd())
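# The embedding model must match the one used to build the index.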
embeddings = HuggingFaceEmbeddings(model_name="intfloat/multilingual-e5-small")
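# load_local unpickles the index file, so deserialization must be explicitly allowed; only do this for trusted sources.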
vector_db = FAISS.load_local("faiss_index", embeddings, allow_dangerous_deserialization=True)
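
# RAG chat handler: retrieve relevant cases, build an augmented prompt, and stream the model's answer.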
def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
    score,
):
    messages = [{"role": "system", "content": system_message}]
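    # Retrieve cases whose similarity score clears the user-selected threshold.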
    retriever = vector_db.as_retriever(search_type="similarity_score_threshold", search_kwargs={"score_threshold": score})
    documents = retriever.invoke(message)
"""
if document == []:
message = message + "\nNo cases were found about this subject"
else:
message = message + "\nUse the following jurisprudence case to answer " + documents[0].page_content + "\n Give the following url " + documents[0].metadata["case_url"]
"""
    context = ""
    for doc in documents:
        context += "Case number: " + doc.metadata["case_nb"] + "\n"
        context += "Case date: " + doc.metadata["case_date"] + "\n"
        context += "Case url: " + doc.metadata["case_url"] + "\n"
        context += "Case chunk: " + doc.page_content + "\n"
message = f"""
The user is asking for information about the following: {message}.
Answer him in his own language using the information from the following Swiss federal jurisprudence cases:
{context}
Please mention your sources in your answer, including the urls
"""
    print(message)  # Log the augmented prompt for debugging.
    # (disabled) Replay prior turns so the model sees the conversation history:
    # for val in history:
    #     if val[0]:
    #         messages.append({"role": "user", "content": val[0]})
    #     if val[1]:
    #         messages.append({"role": "assistant", "content": val[1]})
    messages.append({"role": "user", "content": message})

    response = ""
    for chunk in client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        token = chunk.choices[0].delta.content
        if token:  # the last streamed chunk may carry an empty delta
            response += token
        yield response
"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(value="You are an assistant in Swiss Jurisprudence cases.", label="System message"),
        gr.Slider(minimum=1, maximum=24000, value=8000, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
        gr.Slider(minimum=0, maximum=1, value=0.7, step=0.1, label="Score Threshold"),
    ],
    description="# 📜 ALexI: Artificial Legal Intelligence for Swiss Jurisprudence",
)
if __name__ == "__main__":
    demo.launch(debug=True)