# OR-chatbot / app.py
import os
import gradio as gr
import requests
import json
import logging
import google.generativeai as genai
from dotenv import load_dotenv
# Load environment variables
load_dotenv()
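# Basic logging configuration so the logging.error calls below are emitted to
# the console (the INFO level is an assumed default, not mandated by any API used here).
logging.basicConfig(level=logging.INFO)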
# API Keys configuration
COHERE_API_KEY = os.getenv("COHERE_API_KEY")
MISTRAL_API_KEY = os.getenv("MISTRAL_API_KEY")
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
if not all([COHERE_API_KEY, MISTRAL_API_KEY, GEMINI_API_KEY]):
raise ValueError("Missing required API keys in environment variables")
# Configure Gemini
genai.configure(api_key=GEMINI_API_KEY)
# API endpoints configuration
COHERE_API_URL = "https://api.cohere.ai/v1/chat"
MISTRAL_API_URL = "https://api.mistral.ai/v1/chat/completions"
VECTOR_API_URL = "https://sendthat.cc"
HISTORY_INDEX = "history"
# Model configurations
MODELS = {
"Cohere": {
"name": "command-r-plus-08-2024",
"api_url": COHERE_API_URL,
"api_key": COHERE_API_KEY
},
"Mistral": {
"name": "ft:open-mistral-nemo:ef730d29:20241022:2a0e7d46",
"api_url": MISTRAL_API_URL,
"api_key": MISTRAL_API_KEY
},
"Gemini": {
"name": "gemini-1.5-pro",
"model": genai.GenerativeModel('gemini-1.5-pro'),
"api_key": GEMINI_API_KEY
}
}
def search_document(query, k):
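    """Look up the k passages most similar to `query` in the HISTORY_INDEX of the
    vector-search service at VECTOR_API_URL.

    Returns a (results, query, k) tuple: on success `results` is the parsed JSON
    response and the query slot is an empty string; on failure `results` is a
    dict with an "error" key and the original query is passed back."""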
try:
url = f"{VECTOR_API_URL}/search/{HISTORY_INDEX}"
payload = {"text": query, "k": k}
headers = {"Content-Type": "application/json"}
        # timeout keeps a stalled vector-search call from hanging the request (30s is an assumed limit)
        response = requests.post(url, json=payload, headers=headers, timeout=30)
response.raise_for_status()
return response.json(), "", k
except requests.exceptions.RequestException as e:
logging.error(f"Error in search: {e}")
return {"error": str(e)}, query, k
def generate_answer_cohere(question, context, citations):
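    """Answer `question` against `context` using the Cohere v1 chat endpoint and
    append a numbered Sources list built from `citations`. Returns the answer
    text, or an error string if the request fails."""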
headers = {
"Authorization": f"Bearer {MODELS['Cohere']['api_key']}",
"Content-Type": "application/json"
}
prompt = f"Context: {context}\n\nQuestion: {question}\n\nAnswer the question based on the given context. Include citations as [1], [2], etc.:"
payload = {
"message": prompt,
"model": MODELS['Cohere']['name'],
"preamble": "You are an AI-assistant chatbot. Provide thorough responses with citations.",
"chat_history": []
}
try:
        # 60s timeout is an assumed upper bound for a long Cohere generation
        response = requests.post(MODELS['Cohere']['api_url'], headers=headers, json=payload, timeout=60)
response.raise_for_status()
answer = response.json()['text']
answer += "\n\nSources:"
for i, citation in enumerate(citations, 1):
answer += f"\n[{i}] {citation}"
return answer
except requests.exceptions.RequestException as e:
logging.error(f"Error in generate_answer_cohere: {e}")
return f"An error occurred: {str(e)}"
def generate_answer_mistral(question, context, citations):
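    """Answer `question` against `context` using the Mistral chat-completions
    endpoint with the fine-tuned model configured in MODELS, and append a
    numbered Sources list built from `citations`. Returns the answer text, or
    an error string if the request fails."""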
headers = {
"Authorization": f"Bearer {MODELS['Mistral']['api_key']}",
"Content-Type": "application/json",
"Accept": "application/json"
}
prompt = f"Context: {context}\n\nQuestion: {question}\n\nAnswer the question based on the given context and any pre-trained knowledge. Include citations as [1], [2], etc.:"
payload = {
"model": MODELS['Mistral']['name'],
"messages": [
{
"role": "user",
"content": prompt
}
]
}
try:
        # 60s timeout is an assumed upper bound for a long Mistral generation
        response = requests.post(MODELS['Mistral']['api_url'], headers=headers, json=payload, timeout=60)
response.raise_for_status()
answer = response.json()['choices'][0]['message']['content']
answer += "\n\nSources:"
for i, citation in enumerate(citations, 1):
answer += f"\n[{i}] {citation}"
return answer
except requests.exceptions.RequestException as e:
logging.error(f"Error in generate_answer_mistral: {e}")
return f"An error occurred: {str(e)}"
def generate_answer_gemini(question, context, citations):
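    """Answer `question` against `context` using the Gemini model instance held
    in MODELS, and append a numbered Sources list built from `citations`.
    Returns the answer text, or an error string if generation fails."""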
prompt = f"Context: {context}\n\nQuestion: {question}\n\nAnswer the question based on the given context. Include citations as [1], [2], etc.:"
try:
model = MODELS['Gemini']['model']
response = model.generate_content(
prompt,
generation_config=genai.types.GenerationConfig(
temperature=1.0,
top_k=40,
top_p=0.95,
max_output_tokens=8192,
)
)
answer = response.text
answer += "\n\nSources:"
for i, citation in enumerate(citations, 1):
answer += f"\n[{i}] {citation}"
return answer
except Exception as e:
logging.error(f"Error in generate_answer_gemini: {e}")
return f"An error occurred: {str(e)}"
def answer_question(question, model_choice, k=3):
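    """Retrieve the top-k context passages and citations for `question`, then
    generate an answer with the selected model. Falls back to an empty context
    (and no citations) if retrieval fails or returns no results."""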
# Search the vector database
search_results, _, _ = search_document(question, k)
# Extract and combine the retrieved contexts
if "results" in search_results:
contexts = []
citations = []
for item in search_results['results']:
contexts.append(item['metadata']['content'])
citations.append(f"{item['metadata'].get('title', 'Unknown Source')} - {item['metadata'].get('source', 'No source provided')}")
combined_context = " ".join(contexts)
else:
logging.error(f"Error in database search or no results found: {search_results}")
combined_context = ""
citations = []
# Generate answer using the selected model
if model_choice == "Cohere":
return generate_answer_cohere(question, combined_context, citations)
elif model_choice == "Mistral":
return generate_answer_mistral(question, combined_context, citations)
else:
return generate_answer_gemini(question, combined_context, citations)
def chatbot(message, history, model_choice):
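    """Gradio ChatInterface callback: each turn is answered independently from
    freshly retrieved context, so `history` is accepted but not used."""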
response = answer_question(message, model_choice)
return response
# Example questions with default model choice
EXAMPLE_QUESTIONS = [
["Why was Anne Hutchinson banished from Massachusetts?", "Cohere"],
["What were the major causes of World War I?", "Mistral"],
["Who was the first President of the United States?", "Gemini"],
["What was the significance of the Industrial Revolution?", "Cohere"]
]
# Create Gradio interface
with gr.Blocks(theme="soft") as iface:
gr.Markdown("# History Chatbot")
gr.Markdown("Ask me anything about history, and I'll provide answers with citations!")
with gr.Row():
model_choice = gr.Radio(
choices=["Cohere", "Mistral", "Gemini"],
value="Cohere",
label="Choose LLM Model",
info="Select which AI model to use for generating responses"
)
chatbot_interface = gr.ChatInterface(
        fn=chatbot,
additional_inputs=[model_choice],
chatbot=gr.Chatbot(height=300),
textbox=gr.Textbox(placeholder="Ask a question about history...", container=False, scale=7),
examples=EXAMPLE_QUESTIONS,
cache_examples=False,
retry_btn=None,
undo_btn="Delete Previous",
clear_btn="Clear",
)
# Launch the app
if __name__ == "__main__":
iface.launch()