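"""Gaia Mistral chat demo with RAG.

A Gradio Space that indexes a user-supplied web page with llama_index and
Mistral embeddings, then answers agriculture questions grounded in that page
through the Mistral chat API.
"""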
import os

import gradio as gr
import requests  # only needed by the manual embedding variant sketched below
import numpy as np  # only needed by the manual embedding variant sketched below
from mistralai.client import MistralClient
from mistralai.models.chat_completion import ChatMessage
from llama_index import ServiceContext, VectorStoreIndex
from llama_index.readers import SimpleWebPageReader
from llama_index.llms import MistralAI
from llama_index.embeddings import MistralAIEmbedding
title = "Gaia Mistral Chat Demo with RAG" | |
description = "Exemple d'assistant avec Gradio et Mistral AI via son API" | |
placeholder = "Posez moi une question sur l'agriculture" | |
placeholder_url = "Donner moi une url qui va servir de contexte agricole complémentaire" | |
examples = ["Comment fait on pour produire du maïs ?", "Rédige moi une lettre pour faire un stage dans une exploitation agricole", "Comment reprendre une exploitation agricole ?"] | |
api_key = os.environ.get("MISTRAL_API_KEY")
client = MistralClient(api_key=api_key)
chat_model = 'mistral-small'

llm = MistralAI(api_key=api_key, model="mistral-medium")
embed_model = MistralAIEmbedding(model_name='mistral-embed', api_key=api_key)
service_context = ServiceContext.from_defaults(chunk_size=512, llm=llm, embed_model=embed_model)
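# With chunk_size=512, pages pulled into the index are split into ~512-token
# chunks before embedding: mistral-medium synthesizes answers inside llama_index,
# mistral-embed produces the vectors, and mistral-small powers the final chat turn.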
# build a vector index over the documents behind a URL
def setup_db_with_url(url):
    global query_engine
    # fetch the page and convert its HTML to plain text
    documents = SimpleWebPageReader(html_to_text=True).load_data([url])
    # index the documents; the service context handles chunking and embedding
    index = VectorStoreIndex.from_documents(documents, service_context=service_context)
    query_engine = index.as_query_engine(similarity_top_k=2)
    # manual alternative kept for reference: fetch, chunk, and embed by hand
    # response = requests.get(url)
    # text = response.text
    # chunk_size = 512
    # chunks = [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]
    # text_embeddings = np.array([get_text_embedding(chunk) for chunk in chunks])
    # d = text_embeddings.shape[1]
    return f"Indexed {len(documents)} document(s) from {url}"
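# Usage sketch (the URL is purely illustrative, not part of the original app):
# setup_db_with_url("https://fr.wikipedia.org/wiki/Agriculture")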
# embed a single string with the Mistral embeddings endpoint
# (only needed by the manual variant sketched above)
def get_text_embedding(input):
    embeddings_batch_response = client.embeddings(
        model="mistral-embed",
        input=input
    )
    return embeddings_batch_response.data[0].embedding
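# Quick check (hypothetical): mistral-embed vectors are 1024-dimensional, so
# len(get_text_embedding("bonjour")) should return 1024.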
# build a RAG prompt: retrieved context first, then the user's query
def build_prompt(user_input):
    # note: query() returns an answer synthesized by the llama_index LLM,
    # not raw chunks; it is used here as context for a second chat call
    retrieved_chunk = query_engine.query(user_input)
    prompt = f"""
    Context information is below.
    ---------------------
    {retrieved_chunk}
    ---------------------
    Given the context information and not prior knowledge, answer the query.
    Query: {user_input}
    Answer:
    """
    return prompt
# send the augmented prompt to Mistral and append the exchange to the chat history
def chat_with_mistral(user_input, history):
    prompt = build_prompt(user_input)
    messages = [ChatMessage(role="user", content=prompt)]
    chat_response = client.chat(model=chat_model, messages=messages)
    mistral_content = chat_response.choices[0].message.content
    # Gradio's Chatbot expects (user, assistant) pairs
    histories = history + [(user_input, mistral_content)]
    # clear the input box and return the updated history
    return "", histories
with gr.Blocks(title=title) as iface:
    with gr.Row():
        gr.Markdown(f"# {title}")
    with gr.Row():
        url_msg = gr.Textbox(placeholder=placeholder_url, container=False, scale=7)
        url_btn = gr.Button(value="🔄", interactive=True)
    with gr.Row():
        url_return = gr.Textbox(value='', container=False, scale=7)
    url_btn.click(setup_db_with_url, url_msg, url_return)
    with gr.Row():
        chatbot = gr.Chatbot(height=300)
    with gr.Row():
        msg = gr.Textbox(placeholder=placeholder, container=False, scale=7)
        msg_btn = gr.Button("Send")
    msg_btn.click(chat_with_mistral, [msg, chatbot], [msg, chatbot])

iface.launch(share=True)
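# To run locally (a sketch; the pinned versions are assumptions based on the
# import paths above, which predate mistralai 1.0 and llama-index 0.10):
#   pip install gradio "mistralai<1.0" "llama-index<0.10" html2text
#   export MISTRAL_API_KEY=<your key>
#   python app.py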