Create app.py
app.py ADDED
@@ -0,0 +1,212 @@
import os
import gradio as gr
import requests
import json
import logging
import google.generativeai as genai
from dotenv import load_dotenv

# Load environment variables
load_dotenv()

# API Keys configuration
COHERE_API_KEY = os.getenv("COHERE_API_KEY")
MISTRAL_API_KEY = os.getenv("MISTRAL_API_KEY")
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")

if not all([COHERE_API_KEY, MISTRAL_API_KEY, GEMINI_API_KEY]):
    raise ValueError("Missing required API keys in environment variables")

# Configure Gemini
genai.configure(api_key=GEMINI_API_KEY)

# API endpoints configuration
COHERE_API_URL = "https://api.cohere.ai/v1/chat"
MISTRAL_API_URL = "https://api.mistral.ai/v1/chat/completions"
VECTOR_API_URL = "https://sendthat.cc"
HISTORY_INDEX = "onramps"

# Model configurations
MODELS = {
    "Cohere": {
        "name": "command-r-08-2024",
        "api_url": COHERE_API_URL,
        "api_key": COHERE_API_KEY
    },
    "Mistral": {
        "name": "ft:open-mistral-nemo:ef730d29:20241022:2a0e7d46",
        "api_url": MISTRAL_API_URL,
        "api_key": MISTRAL_API_KEY
    },
    "Gemini": {
        "name": "gemini-1.5-pro",
        "model": genai.GenerativeModel('gemini-1.5-pro'),
        "api_key": GEMINI_API_KEY
    }
}

def search_document(query, k):
    try:
        url = f"{VECTOR_API_URL}/search/{HISTORY_INDEX}"
        payload = {"text": query, "k": k}
        headers = {"Content-Type": "application/json"}
        response = requests.post(url, json=payload, headers=headers)
        response.raise_for_status()
        return response.json(), "", k
    except requests.exceptions.RequestException as e:
        logging.error(f"Error in search: {e}")
        return {"error": str(e)}, query, k

def generate_answer_cohere(question, context, citations):
    headers = {
        "Authorization": f"Bearer {MODELS['Cohere']['api_key']}",
        "Content-Type": "application/json"
    }

    prompt = f"Context: {context}\n\nQuestion: {question}\n\nAnswer the question based on the given context. Include citations as [1], [2], etc.:"

    payload = {
        "message": prompt,
        "model": MODELS['Cohere']['name'],
        "preamble": "You are an AI-assistant chatbot. Provide thorough responses with citations.",
        "chat_history": []
    }

    try:
        response = requests.post(MODELS['Cohere']['api_url'], headers=headers, json=payload)
        response.raise_for_status()
        answer = response.json()['text']

        answer += "\n\nSources:"
        for i, citation in enumerate(citations, 1):
            answer += f"\n[{i}] {citation}"

        return answer
    except requests.exceptions.RequestException as e:
        logging.error(f"Error in generate_answer_cohere: {e}")
        return f"An error occurred: {str(e)}"

def generate_answer_mistral(question, context, citations):
    headers = {
        "Authorization": f"Bearer {MODELS['Mistral']['api_key']}",
        "Content-Type": "application/json",
        "Accept": "application/json"
    }

    prompt = f"Context: {context}\n\nQuestion: {question}\n\nAnswer the question based on the given context. Include citations as [1], [2], etc.:"

    payload = {
        "model": MODELS['Mistral']['name'],
        "messages": [
            {
                "role": "user",
                "content": prompt
            }
        ]
    }

    try:
        response = requests.post(MODELS['Mistral']['api_url'], headers=headers, json=payload)
        response.raise_for_status()
        answer = response.json()['choices'][0]['message']['content']

        answer += "\n\nSources:"
        for i, citation in enumerate(citations, 1):
            answer += f"\n[{i}] {citation}"

        return answer
    except requests.exceptions.RequestException as e:
        logging.error(f"Error in generate_answer_mistral: {e}")
        return f"An error occurred: {str(e)}"

def generate_answer_gemini(question, context, citations):
    prompt = f"Context: {context}\n\nQuestion: {question}\n\nAnswer the question based on the given context. Include citations as [1], [2], etc.:"

    try:
        model = MODELS['Gemini']['model']
        response = model.generate_content(
            prompt,
            generation_config=genai.types.GenerationConfig(
                temperature=1.0,
                top_k=40,
                top_p=0.95,
                max_output_tokens=8192,
            )
        )

        answer = response.text

        answer += "\n\nSources:"
        for i, citation in enumerate(citations, 1):
            answer += f"\n[{i}] {citation}"

        return answer
    except Exception as e:
        logging.error(f"Error in generate_answer_gemini: {e}")
        return f"An error occurred: {str(e)}"

def answer_question(question, model_choice, k=3):
    # Search the vector database
    search_results, _, _ = search_document(question, k)

    # Extract and combine the retrieved contexts
    if "results" in search_results:
        contexts = []
        citations = []
        for item in search_results['results']:
            contexts.append(item['metadata']['content'])
            citations.append(f"{item['metadata'].get('title', 'Unknown Source')} - {item['metadata'].get('source', 'No source provided')}")
        combined_context = " ".join(contexts)
    else:
        logging.error(f"Error in database search or no results found: {search_results}")
        combined_context = ""
        citations = []

    # Generate answer using the selected model
    if model_choice == "Cohere":
        return generate_answer_cohere(question, combined_context, citations)
    elif model_choice == "Mistral":
        return generate_answer_mistral(question, combined_context, citations)
    else:
        return generate_answer_gemini(question, combined_context, citations)

def chatbot(message, history, model_choice):
    response = answer_question(message, model_choice)
    return response

# Example questions with default model choice
EXAMPLE_QUESTIONS = [
    ["Why was Anne Hutchinson banished from Massachusetts?", "Cohere"],
    ["What were the major causes of World War I?", "Mistral"],
    ["Who was the first President of the United States?", "Gemini"],
    ["What was the significance of the Industrial Revolution?", "Cohere"]
]

# Create Gradio interface
with gr.Blocks(theme="soft") as iface:
    gr.Markdown("# History Chatbot")
    gr.Markdown("Ask me anything about history, and I'll provide answers with citations!")

    with gr.Row():
        model_choice = gr.Radio(
            choices=["Cohere", "Mistral", "Gemini"],
            value="Cohere",
            label="Choose LLM Model",
            info="Select which AI model to use for generating responses"
        )

    chatbot_interface = gr.ChatInterface(
        fn=lambda message, history, model: chatbot(message, history, model),
        additional_inputs=[model_choice],
        chatbot=gr.Chatbot(height=300),
        textbox=gr.Textbox(placeholder="Ask a question about history...", container=False, scale=7),
        examples=EXAMPLE_QUESTIONS,
        cache_examples=False,
        retry_btn=None,
        undo_btn="Delete Previous",
        clear_btn="Clear",
    )

# Launch the app
if __name__ == "__main__":
    iface.launch()
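
Note for running this Space locally: the commit adds only app.py, so the dependency and secrets files below are a minimal sketch inferred from the imports and os.getenv calls above (file names, package pins, and key values are assumptions, not part of this commit).

requirements.txt
    gradio
    requests
    python-dotenv
    google-generativeai

.env
    COHERE_API_KEY=your-cohere-key
    MISTRAL_API_KEY=your-mistral-key
    GEMINI_API_KEY=your-gemini-key

With those in place, `python app.py` should start the Gradio interface; on Hugging Face Spaces the three keys would instead be set as Space secrets.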