# chatbot / app.py
# Uploaded by: umaiku
# Commit 1eaeeb8 (verified): "Update app.py"
# File size: 4.31 kB
# (The original lines here were file-viewer page chrome: "raw", "history blame".)
import gradio as gr
from transformers import pipeline
from huggingface_hub import InferenceClient, login, snapshot_download
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
import os
import pandas as pd
"""
For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
"""
# --- One-time start-up: authentication, LLM client, retrieval index and case table. ---
# The statements below run at import time and are order-dependent.

# Hugging Face access token, read from the TOKEN environment variable.
# NOTE(review): os.getenv returns None when TOKEN is unset, in which case login(None)
# is called -- confirm how that behaves in a non-interactive deployment.
HF_TOKEN=os.getenv('TOKEN')
login(HF_TOKEN)
# Alternative models that were tried previously (kept for reference):
#model = "meta-llama/Llama-3.2-1B-Instruct"
#model = "google/mt5-small"
model = "mistralai/Mistral-7B-Instruct-v0.3"
# Client used for the chat-completion calls against `model`.
client = InferenceClient(model)
# Download the "umaiku/faiss_index" dataset repo into the current working directory.
# `folder` is not read again in the visible code; the relative paths used below
# presumably resolve inside this download -- verify against the repo layout.
folder = snapshot_download(repo_id="umaiku/faiss_index", repo_type="dataset", local_dir=os.getcwd())
# Embedding model handed to FAISS.load_local; presumably the same model the index was
# built with -- TODO confirm.
embeddings = HuggingFaceEmbeddings(model_name="intfloat/multilingual-e5-small")
# NOTE(review): allow_dangerous_deserialization=True opts in to unpickling the stored
# index data; this is only acceptable as long as the dataset repo above is trusted.
vector_db = FAISS.load_local("faiss_index_8k", embeddings, allow_dangerous_deserialization=True)
# Case table loaded from the CSV file under faiss_index/.
df = pd.read_csv("faiss_index/bger_cedh_db 1954-2024.csv")
def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
    score,
):
    """Answer a user question from retrieved case extracts, streaming the reply.

    The question is used to query the module-level ``vector_db``; every
    retrieved extract is formatted into a context section that is embedded,
    together with the question, in a single user prompt. That prompt is sent
    with ``system_message`` to the module-level ``client`` and the answer is
    streamed back.

    Args:
        message: The user's question.
        history: Previous (user, assistant) turns supplied by the chat UI.
            Not forwarded to the model: each question is answered on its own.
        system_message: Content of the system message sent to the model.
        max_tokens: Passed through as ``max_tokens`` to the chat completion.
        temperature: Passed through as ``temperature``.
        top_p: Passed through as ``top_p``.
        score: Value used as ``score_threshold`` for the similarity search.

    Yields:
        The answer accumulated so far, once per streamed chunk that carries
        text.
    """
    print(system_message)

    retriever = vector_db.as_retriever(
        search_type="similarity_score_threshold",
        search_kwargs={"score_threshold": score},
    )
    documents = retriever.invoke(message)
    print(len(documents))

    # Build the context from the retrieved extracts only. f-strings are used so
    # that a non-string metadata value cannot break the concatenation.
    spacer = " \n"
    context_parts = []
    for doc in documents:
        meta = doc.metadata
        if meta["case_ref"] == "ATF":
            source = "Swiss Federal Court"
        else:
            source = "European Court of Human Rights"
        context_parts.append(
            f"Case number: {meta['case_nb']}{spacer}"
            f"Case source: {source}{spacer}"
            f"Case date: {meta['case_date']}{spacer}"
            f"Case url: {meta['case_url']}{spacer}"
            f"Case text: {doc.page_content}{spacer}"
        )
    context = "".join(context_parts)

    # The prompt gets its own name so the `message` parameter keeps meaning
    # "the user's question" for the whole function.
    prompt = f"""
A user is asking you the following question: {message}
Please answer the user in the same language that he used in his question using the following given context not prior knowledge.
Context:
The following case extracts from various Swiss Federal Court and European Court of Human Rights cases have been found to fit the question :
{context}
Task:
Start by summarizing these cases in the user's question's language and reference the sources, including the urls and dates.
If no relevant docs were retrieved or the issue has not been addressed in the context, just say "You can't find enough relevant information".
Don't make up an answer or give irrelevant information not requested by the user .
Instructions:
Always answer the user using the language used in his question: {message}
"""
    print(prompt)

    messages = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": prompt},
    ]

    response = ""
    for chunk in client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        # A streamed chunk may carry no choices, or a delta without text
        # (content is None); appending that would raise TypeError.
        if not chunk.choices:
            continue
        token = chunk.choices[0].delta.content
        if not token:
            continue
        response += token
        yield response
"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""
# Chat UI around `respond`. The additional inputs are passed to `respond` after
# (message, history), so their order must match its remaining parameters:
# system_message, max_tokens, temperature, top_p, score.
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(
            value="You are an assistant in Swiss Jurisprudence cases.",
            label="System message",
        ),
        gr.Slider(minimum=1, maximum=24000, value=5000, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
        gr.Slider(minimum=0, maximum=1, value=0.7, step=0.1, label="Score Threshold"),
    ],
    # \U0001F4DC is the scroll emoji; it is written as an escape so the heading
    # cannot be garbled again by a wrong file encoding.
    description="# \U0001F4DC ALexI: Artificial Legal Intelligence for Swiss Jurisprudence",
)

if __name__ == "__main__":
    # Start the web app only when this file is run as a script.
    demo.launch(debug=True)