import os
import logging

import requests
import gradio as gr
import google.generativeai as genai
from dotenv import load_dotenv

# Configure logging so errors from the handlers below are visible
logging.basicConfig(level=logging.INFO)

# Load environment variables
load_dotenv()
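# Expected .env layout (key names taken from the os.getenv calls below):
#   COHERE_API_KEY=...
#   MISTRAL_API_KEY=...
#   GEMINI_API_KEY=...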
# API keys configuration
COHERE_API_KEY = os.getenv("COHERE_API_KEY")
MISTRAL_API_KEY = os.getenv("MISTRAL_API_KEY")
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")

if not all([COHERE_API_KEY, MISTRAL_API_KEY, GEMINI_API_KEY]):
    raise ValueError("Missing required API keys in environment variables")

# Configure Gemini
genai.configure(api_key=GEMINI_API_KEY)

# API endpoints configuration
COHERE_API_URL = "https://api.cohere.ai/v1/chat"
MISTRAL_API_URL = "https://api.mistral.ai/v1/chat/completions"
VECTOR_API_URL = "https://sendthat.cc"
HISTORY_INDEX = "history"
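# VECTOR_API_URL points at a custom vector-search service; search_document
# below assumes it exposes POST {VECTOR_API_URL}/search/{index_name}.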
# Model configurations
MODELS = {
    "Cohere": {
        "name": "command-r-plus-08-2024",
        "api_url": COHERE_API_URL,
        "api_key": COHERE_API_KEY,
    },
    "Mistral": {
        "name": "ft:open-mistral-nemo:ef730d29:20241022:2a0e7d46",
        "api_url": MISTRAL_API_URL,
        "api_key": MISTRAL_API_KEY,
    },
    "Gemini": {
        "name": "gemini-1.5-pro",
        "model": genai.GenerativeModel("gemini-1.5-pro"),
        "api_key": GEMINI_API_KEY,
    },
}
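# Note the asymmetry: the Cohere and Mistral entries carry REST endpoints,
# while the Gemini entry holds an instantiated google.generativeai client.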
def search_document(query, k):
    """Retrieve the top-k passages for `query` from the vector-search service."""
    try:
        url = f"{VECTOR_API_URL}/search/{HISTORY_INDEX}"
        payload = {"text": query, "k": k}
        headers = {"Content-Type": "application/json"}
        response = requests.post(url, json=payload, headers=headers, timeout=30)
        response.raise_for_status()
        return response.json(), "", k
    except requests.exceptions.RequestException as e:
        logging.error(f"Error in search: {e}")
        return {"error": str(e)}, query, k
def generate_answer_cohere(question, context, citations):
    """Answer `question` from `context` via Cohere's chat API, appending citations."""
    headers = {
        "Authorization": f"Bearer {MODELS['Cohere']['api_key']}",
        "Content-Type": "application/json",
    }
    prompt = f"Context: {context}\n\nQuestion: {question}\n\nAnswer the question based on the given context. Include citations as [1], [2], etc.:"
    payload = {
        "message": prompt,
        "model": MODELS['Cohere']['name'],
        "preamble": "You are an AI-assistant chatbot. Provide thorough responses with citations.",
        "chat_history": [],
    }
    try:
        response = requests.post(MODELS['Cohere']['api_url'], headers=headers, json=payload, timeout=60)
        response.raise_for_status()
        answer = response.json()['text']
        answer += "\n\nSources:"
        for i, citation in enumerate(citations, 1):
            answer += f"\n[{i}] {citation}"
        return answer
    except requests.exceptions.RequestException as e:
        logging.error(f"Error in generate_answer_cohere: {e}")
        return f"An error occurred: {str(e)}"
def generate_answer_mistral(question, context, citations):
    """Answer `question` from `context` via Mistral's chat completions API."""
    headers = {
        "Authorization": f"Bearer {MODELS['Mistral']['api_key']}",
        "Content-Type": "application/json",
        "Accept": "application/json",
    }
    prompt = f"Context: {context}\n\nQuestion: {question}\n\nAnswer the question based on the given context and any pre-trained knowledge. Include citations as [1], [2], etc.:"
    payload = {
        "model": MODELS['Mistral']['name'],
        "messages": [{"role": "user", "content": prompt}],
    }
    try:
        response = requests.post(MODELS['Mistral']['api_url'], headers=headers, json=payload, timeout=60)
        response.raise_for_status()
        answer = response.json()['choices'][0]['message']['content']
        answer += "\n\nSources:"
        for i, citation in enumerate(citations, 1):
            answer += f"\n[{i}] {citation}"
        return answer
    except requests.exceptions.RequestException as e:
        logging.error(f"Error in generate_answer_mistral: {e}")
        return f"An error occurred: {str(e)}"
def generate_answer_gemini(question, context, citations):
    """Answer `question` from `context` via the Gemini client, appending citations."""
    prompt = f"Context: {context}\n\nQuestion: {question}\n\nAnswer the question based on the given context. Include citations as [1], [2], etc.:"
    try:
        model = MODELS['Gemini']['model']
        response = model.generate_content(
            prompt,
            generation_config=genai.types.GenerationConfig(
                temperature=1.0,
                top_k=40,
                top_p=0.95,
                max_output_tokens=8192,
            ),
        )
        answer = response.text
        answer += "\n\nSources:"
        for i, citation in enumerate(citations, 1):
            answer += f"\n[{i}] {citation}"
        return answer
    except Exception as e:
        logging.error(f"Error in generate_answer_gemini: {e}")
        return f"An error occurred: {str(e)}"
def answer_question(question, model_choice, k=3):
    """Retrieve context for `question`, then generate an answer with the chosen model."""
    # Search the vector database
    search_results, _, _ = search_document(question, k)

    # Extract and combine the retrieved contexts
    if "results" in search_results:
        contexts = []
        citations = []
        for item in search_results['results']:
            contexts.append(item['metadata']['content'])
            citations.append(
                f"{item['metadata'].get('title', 'Unknown Source')} - "
                f"{item['metadata'].get('source', 'No source provided')}"
            )
        combined_context = " ".join(contexts)
    else:
        # Fall back to the model's own knowledge when retrieval fails
        logging.error(f"Error in database search or no results found: {search_results}")
        combined_context = ""
        citations = []

    # Generate answer using the selected model
    if model_choice == "Cohere":
        return generate_answer_cohere(question, combined_context, citations)
    elif model_choice == "Mistral":
        return generate_answer_mistral(question, combined_context, citations)
    else:
        return generate_answer_gemini(question, combined_context, citations)
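# Quick sanity check (assumes valid API keys and a reachable vector service):
#   print(answer_question("Who was the first President of the United States?", "Gemini"))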
def chatbot(message, history, model_choice):
    # `history` is supplied by gr.ChatInterface but unused: retrieval is
    # stateless, so only the latest message and model choice matter.
    response = answer_question(message, model_choice)
    return response
# Example questions with default model choice
EXAMPLE_QUESTIONS = [
    ["Why was Anne Hutchinson banished from Massachusetts?", "Cohere"],
    ["What were the major causes of World War I?", "Mistral"],
    ["Who was the first President of the United States?", "Gemini"],
    ["What was the significance of the Industrial Revolution?", "Cohere"],
]
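# Each example is a [message, model_choice] pair, matching the
# additional_inputs wired into gr.ChatInterface below.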
# Create Gradio interface
with gr.Blocks(theme="soft") as iface:
    gr.Markdown("# History Chatbot")
    gr.Markdown("Ask me anything about history, and I'll provide answers with citations!")

    with gr.Row():
        model_choice = gr.Radio(
            choices=["Cohere", "Mistral", "Gemini"],
            value="Cohere",
            label="Choose LLM Model",
            info="Select which AI model to use for generating responses",
        )

    chatbot_interface = gr.ChatInterface(
        fn=chatbot,
        additional_inputs=[model_choice],
        chatbot=gr.Chatbot(height=300),
        textbox=gr.Textbox(placeholder="Ask a question about history...", container=False, scale=7),
        examples=EXAMPLE_QUESTIONS,
        cache_examples=False,
        retry_btn=None,
        undo_btn="Delete Previous",
        clear_btn="Clear",
    )
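# retry_btn/undo_btn/clear_btn are Gradio 4.x ChatInterface arguments that were
# removed in Gradio 5; pin gradio<5 in requirements.txt (an assumption about
# this Space's dependencies) or drop them when upgrading.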
# Launch the app
if __name__ == "__main__":
    iface.launch()