Spaces:
Sleeping
Sleeping
import os | |
__import__('pysqlite3') | |
import sys | |
sys.modules['sqlite3'] = sys.modules.pop('pysqlite3') | |
from dotenv import load_dotenv | |
import json | |
import gradio as gr | |
import chromadb | |
from llama_index.core import ( | |
VectorStoreIndex, | |
StorageContext, | |
Settings, | |
download_loader, | |
) | |
from llama_index.llms.mistralai import MistralAI | |
from llama_index.embeddings.mistralai import MistralAIEmbedding | |
from llama_index.vector_stores.chroma import ChromaVectorStore | |
# Load variables from a local .env file (if any) into the process environment.
load_dotenv()

# Text shown in the Gradio interface.
title = "AgreenDefi Gaia 8x22b PDF Demo"
description = "Example of an assistant with Gradio, RAG from PDF documents and Mistral AI via its API"
placeholder = (
    "Vous pouvez me poser une question sur ce contexte, appuyez sur Entrée pour valider"
)

# Mistral chat model used to answer questions, and the API key read from the
# MISTRAL_API_KEY environment variable (None when it is not set).
llm_model = "open-mixtral-8x22b"
env_api_key = os.environ.get("MISTRAL_API_KEY")

# Placeholder; reassigned further down once the vector index has been built.
query_engine = None
# --- Models -----------------------------------------------------------------
# Chat model and embedding model, both reached through the Mistral API with
# the key read from the environment.
llm = MistralAI(model=llm_model, api_key=env_api_key)
embed_model = MistralAIEmbedding(api_key=env_api_key, model_name="mistral-embed")

# --- Vector store -----------------------------------------------------------
# On-disk Chroma database under ./chroma_db; the "quickstart" collection is
# opened if it exists, otherwise created.
db = chromadb.PersistentClient(path="./chroma_db")
chroma_collection = db.get_or_create_collection("quickstart")
vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
storage_context = StorageContext.from_defaults(vector_store=vector_store)

# --- Global LlamaIndex defaults --------------------------------------------
# The index and query engine below are not given an LLM or embedding model
# explicitly, so they rely on these Settings; keep this before index creation.
Settings.llm = llm
Settings.embed_model = embed_model
Settings.chunk_size = 1024

# --- PDF loader -------------------------------------------------------------
# NOTE(review): download_loader is deprecated in recent llama_index releases in
# favour of the llama-index-readers-file package — confirm against the pinned
# version before upgrading.
PDFReader = download_loader("PDFReader")
loader = PDFReader()

# --- Index and query engine -------------------------------------------------
# No new nodes are passed: the index is backed by the Chroma collection, so
# whatever is already stored there stays queryable across restarts.
index = VectorStoreIndex(nodes=[], storage_context=storage_context)
query_engine = index.as_query_engine(similarity_top_k=5)
def _file_name_from_metadata(meta):
    """Return the ``file_name`` stored in one Chroma metadata dict, or None if absent/malformed."""
    if not meta:
        return None
    try:
        # LlamaIndex serialises the node as JSON under "_node_content"; the
        # source file name lives in its nested "metadata" mapping.
        name = json.loads(meta["_node_content"])["metadata"]["file_name"]
    except (KeyError, TypeError, ValueError):
        # ValueError also covers json.JSONDecodeError for corrupt payloads.
        return None
    return name if isinstance(name, str) and name else None


def get_documents_in_db(collection=None):
    """Return a Markdown bullet list of the distinct file names in the vector DB.

    Records whose metadata is missing or does not carry a readable
    ``_node_content`` -> ``metadata`` -> ``file_name`` entry are skipped
    instead of aborting the whole listing.

    Args:
        collection: Chroma collection to inspect. Defaults to the module-level
            ``chroma_collection`` so existing zero-argument callers (the Gradio
            ``value=`` callback, ``load_file``, ``empty_db``) keep working.

    Returns:
        A Markdown string: a bold header followed by one ``" - <name>"`` line
        per distinct file name, sorted alphabetically for a stable display.
    """
    if collection is None:
        collection = chroma_collection
    print("Fetching documents in DB")
    metadatas = collection.get(include=["metadatas"])["metadatas"] or []
    names = set()
    for meta in metadatas:
        name = _file_name_from_metadata(meta)
        if name is not None:
            names.add(name)
    docs = sorted(names)
    print(f"Found {len(docs)} documents")
    # join instead of repeated += to avoid quadratic string building.
    return "**List of files in db:**\n" + "".join(f" - {d}\n" for d in docs)
def empty_db():
    """Delete every record from the Chroma collection and return the refreshed listing.

    Returns:
        The Markdown file listing produced by ``get_documents_in_db`` after
        the deletion.
    """
    # Only ids are needed; include=[] avoids fetching documents and metadata.
    ids = chroma_collection.get(include=[])["ids"]
    # Skip the call when there is nothing to remove: depending on the chromadb
    # version, delete() with an empty id list may raise instead of no-op.
    if ids:
        chroma_collection.delete(ids=ids)
    return get_documents_in_db()
def load_file(files):
    """Parse the uploaded PDFs and insert the resulting documents into the index.

    Args:
        files: Value of the ``gr.File`` component (``type="filepath"``,
            ``file_count="multiple"``), i.e. a list of file paths; empty or
            None when the button is clicked before anything was uploaded.

    Returns:
        A 3-tuple of updates for ``(file_msg, btn_msg, db_list)``: hide the
        upload message, show a status message, and refresh the DB listing.
    """
    if not files:
        # Nothing to encode; iterating None below would raise TypeError.
        return (
            gr.Textbox(visible=False),
            gr.Textbox(value="No file to encode: please upload a PDF first.", visible=True),
            get_documents_in_db(),
        )
    for file in files:
        for doc in loader.load_data(file=file):
            index.insert(doc)
    return (
        gr.Textbox(visible=False),
        gr.Textbox(value="Document encoded ! You can ask questions", visible=True),
        get_documents_in_db(),
    )
def load_document(input_file):
    """Build the status message shown right after files are uploaded.

    Args:
        input_file: Upload value from ``gr.File``. With ``file_count="multiple"``
            and ``type="filepath"`` this is a list of path strings; a single
            path, or an object exposing ``.name``, is also accepted.

    Returns:
        A visible ``gr.Textbox`` naming the uploaded file(s), comma-separated.
    """
    # The component is configured for multiple files, so the value is a list;
    # calling .name on it directly (as before) raised AttributeError.
    files = input_file if isinstance(input_file, (list, tuple)) else [input_file]
    # os.path.basename uses the platform separator instead of a hard-coded "/".
    names = [os.path.basename(getattr(f, "name", f)) for f in files if f]
    return gr.Textbox(value=f"Document loaded: {', '.join(names)}", visible=True)
# Page title passed to the constructor: assigning demo.title after the
# `with` block may be too late for the already-built config.
with gr.Blocks(title=title) as demo:
    gr.Markdown(
        """ # Bienvenue sur la démo AgreenDefi PDF

Ajouter un fichier avant de poser une question sur le tchat.

Cette démo vous permet d'interagir entre des fichiers PDF et Mistral AI via son API.
Mistral va répondre à vos questions par rapport au document.

*The files will stay in the database unless there is 48h of inactivity or you re-build the space.*
"""
    )

    # --- Step 1: upload and encode the PDFs ---------------------------------
    gr.Markdown(""" ### 1 / Préparer les données """)
    with gr.Row():
        with gr.Column():
            input_file = gr.File(
                label="Charger des fichiers pdf",
                file_types=[".pdf"],
                file_count="multiple",
                type="filepath",
                interactive=True,
            )
            file_msg = gr.Textbox(
                label="Loaded documents:", container=False, visible=False
            )
            # Show the uploaded file names as soon as an upload completes.
            input_file.upload(
                fn=load_document,
                inputs=[input_file],
                outputs=[file_msg],
                concurrency_limit=20,
            )

            help_msg = gr.Markdown(
                value="Quand le document est chargé, appuyez sur « Encoder les fichiers » pour l'ajouter dans la base de données."
            )
            file_btn = gr.Button(value="Encoder les fichiers ✅", interactive=True)
            btn_msg = gr.Textbox(container=False, visible=False)

            with gr.Row():
                # Callable value: the listing is recomputed each time the page loads.
                db_list = gr.Markdown(value=get_documents_in_db)
                delete_btn = gr.Button(value="Vider la base 🗑️", interactive=True, scale=0)

            file_btn.click(
                load_file,
                inputs=[input_file],
                outputs=[file_msg, btn_msg, db_list],
                show_progress="full",
            )
            delete_btn.click(empty_db, outputs=[db_list], show_progress="minimal")

    # --- Step 2: chat over the indexed documents ----------------------------
    gr.Markdown(""" ### 2 / Poser une question selon le contexte """)
    chatbot = gr.Chatbot()
    msg = gr.Textbox(placeholder=placeholder)
    clear = gr.ClearButton([msg, chatbot])

    def respond(message, chat_history):
        """Answer *message* with the RAG query engine and append the turn to the history.

        Empty or whitespace-only messages are ignored (no LLM call) and the
        history is returned unchanged.
        """
        chat_history = chat_history or []
        if not message or not message.strip():
            return chat_history
        response = query_engine.query(message)
        chat_history.append((message, str(response)))
        return chat_history

    msg.submit(respond, [msg, chatbot], [chatbot])

demo.launch()