# chatbot-g-rag / app.py
import os

import faiss
import gradio as gr
import numpy as np
import requests
from mistralai.client import MistralClient
from mistralai.models.chat_completion import ChatMessage
from llama_index import ServiceContext, VectorStoreIndex
from llama_index.readers import SimpleWebPageReader
from llama_index.llms import MistralAI
from llama_index.embeddings import MistralAIEmbedding
title = "Gaia Mistral Chat Demo with RAG"
description = "Exemple d'assistant avec Gradio et Mistral AI via son API"
placeholder = "Posez moi une question sur l'agriculture"
placeholder_url = "Donner moi une url qui va servir de contexte agricole complémentaire"
examples = ["Comment fait on pour produire du maïs ?", "Rédige moi une lettre pour faire un stage dans une exploitation agricole", "Comment reprendre une exploitation agricole ?"]
# Clients: the raw Mistral client is used for chat, the LlamaIndex wrappers for RAG.
api_key = os.environ.get("MISTRAL_API_KEY")
client = MistralClient(api_key=api_key)
chat_model = "mistral-small"

llm = MistralAI(api_key=api_key, model="mistral-medium")
embed_model = MistralAIEmbedding(model_name="mistral-embed", api_key=api_key)
service_context = ServiceContext.from_defaults(chunk_size=512, llm=llm, embed_model=embed_model)

# Build a vector index from a web page and expose it globally as a query engine.
# (A manual requests/FAISS variant of this pipeline is sketched further below.)
def setup_db_with_url(url):
    global query_engine
    # Fetch the page and convert its HTML to plain-text documents.
    documents = SimpleWebPageReader(html_to_text=True).load_data([url])
    # Index the documents with the Mistral embedding model configured above.
    index = VectorStoreIndex.from_documents(documents, service_context=service_context)
    query_engine = index.as_query_engine(similarity_top_k=2)
    return f"Indexed {len(documents)} document(s) from {url}"
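
# Hedged note: as_query_engine().query() returns an LLM-synthesized answer rather
# than raw retrieved chunks. If raw chunk text is wanted as prompt context, a
# retriever can be used instead; a minimal sketch (the function name
# retrieve_context and the top_k default are illustrative, not part of the app):
def retrieve_context(index, question, top_k=2):
    retriever = index.as_retriever(similarity_top_k=top_k)
    nodes = retriever.retrieve(question)
    return "\n".join(node.get_content() for node in nodes)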

# Embed a single piece of text with Mistral's embedding endpoint.
def get_text_embedding(text):
    embeddings_batch_response = client.embeddings(model="mistral-embed", input=text)
    return embeddings_batch_response.data[0].embedding
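
# Minimal sketch of the manual pipeline that the LlamaIndex call in
# setup_db_with_url replaces: fetch the page with requests, split it into
# fixed-size chunks, embed each chunk with get_text_embedding, and index the
# vectors in FAISS. The function name setup_db_manual and the chunking
# parameters are illustrative assumptions.
def setup_db_manual(url, chunk_size=512):
    text = requests.get(url).text
    # Split the raw page into fixed-size character chunks.
    chunks = [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]
    # Embed every chunk so FAISS can index the vectors (one API call per chunk).
    text_embeddings = np.array([get_text_embedding(chunk) for chunk in chunks], dtype=np.float32)
    index = faiss.IndexFlatL2(text_embeddings.shape[1])
    index.add(text_embeddings)
    # To retrieve, embed the question and search, e.g.:
    #   scores, ids = index.search(np.array([get_text_embedding(q)], dtype=np.float32), k=2)
    return index, chunks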

# Build a RAG prompt: retrieve context for the question and wrap both in a template.
def build_prompt(user_input):
    retrieved_chunk = query_engine.query(user_input)
    prompt = f"""
Context information is below.
---------------------
{retrieved_chunk}
---------------------
Given the context information and not prior knowledge, answer the query.
Query: {user_input}
Answer:
"""
    return prompt

# Answer a user message using the RAG prompt and append the exchange to the history.
def chat_with_mistral(user_input, history):
    prompt = build_prompt(user_input)
    messages = [ChatMessage(role="user", content=prompt)]
    chat_response = client.chat(model=chat_model, messages=messages)
    mistral_content = chat_response.choices[0].message.content
    # Gradio's Chatbot expects (user, assistant) pairs; clear the input box afterwards.
    histories = history + [(user_input, mistral_content)]
    return "", histories

with gr.Blocks() as iface:
    with gr.Row():
        gr.Markdown(f"# {title}\n\n{description}")
    with gr.Row():
        url_msg = gr.Textbox(placeholder=placeholder_url, container=False, scale=7)
        url_btn = gr.Button(value="🔄", interactive=True)
    with gr.Row():
        url_return = gr.Textbox(value="", container=False, scale=7)
    url_btn.click(setup_db_with_url, url_msg, url_return)
    with gr.Row():
        chatbot = gr.Chatbot(height=300)
    with gr.Row():
        msg = gr.Textbox(placeholder=placeholder, container=False, scale=7)
        msg_btn = gr.Button("Send")
    msg_btn.click(chat_with_mistral, [msg, chatbot], [msg, chatbot])
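    # Hedged addition: the examples list defined above was otherwise unused;
    # wiring it to the message box via gr.Examples is an assumption.
    gr.Examples(examples=examples, inputs=msg)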
iface.title = title
iface.launch(share=True)