ambrosfitz committed
Commit 8b95db4 · verified · 1 Parent(s): dcd3bb8

Create app.py

Files changed (1): app.py +212 -0
app.py ADDED
@@ -0,0 +1,212 @@
+ import os
+ import gradio as gr
+ import requests
+ import json
+ import logging
+ import google.generativeai as genai
+ from dotenv import load_dotenv
+
+ # Load environment variables
+ load_dotenv()
+
+ # API Keys configuration
+ COHERE_API_KEY = os.getenv("COHERE_API_KEY")
+ MISTRAL_API_KEY = os.getenv("MISTRAL_API_KEY")
+ GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
+
+ if not all([COHERE_API_KEY, MISTRAL_API_KEY, GEMINI_API_KEY]):
+     raise ValueError("Missing required API keys in environment variables")
+
+ # Configure Gemini
+ genai.configure(api_key=GEMINI_API_KEY)
+
+ # API endpoints configuration
+ COHERE_API_URL = "https://api.cohere.ai/v1/chat"
+ MISTRAL_API_URL = "https://api.mistral.ai/v1/chat/completions"
+ VECTOR_API_URL = "https://sendthat.cc"
+ HISTORY_INDEX = "onramps"
+
+ # Model configurations
+ MODELS = {
+     "Cohere": {
+         "name": "command-r-08-2024",
+         "api_url": COHERE_API_URL,
+         "api_key": COHERE_API_KEY
+     },
+     "Mistral": {
+         "name": "ft:open-mistral-nemo:ef730d29:20241022:2a0e7d46",
+         "api_url": MISTRAL_API_URL,
+         "api_key": MISTRAL_API_KEY
+     },
+     "Gemini": {
+         "name": "gemini-1.5-pro",
+         "model": genai.GenerativeModel('gemini-1.5-pro'),
+         "api_key": GEMINI_API_KEY
+     }
+ }
+
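Not part of the commit: the startup check above assumes COHERE_API_KEY, MISTRAL_API_KEY, and GEMINI_API_KEY are all present in the environment (for example via a local .env file read by load_dotenv). A quick pre-flight check before launching the app might look like this:

# Illustrative pre-flight check, not part of app.py.
import os
from dotenv import load_dotenv

load_dotenv()
for name in ("COHERE_API_KEY", "MISTRAL_API_KEY", "GEMINI_API_KEY"):
    print(name, "is set" if os.getenv(name) else "is MISSING")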
+ def search_document(query, k):
+     try:
+         url = f"{VECTOR_API_URL}/search/{HISTORY_INDEX}"
+         payload = {"text": query, "k": k}
+         headers = {"Content-Type": "application/json"}
+         response = requests.post(url, json=payload, headers=headers)
+         response.raise_for_status()
+         return response.json(), "", k
+     except requests.exceptions.RequestException as e:
+         logging.error(f"Error in search: {e}")
+         return {"error": str(e)}, query, k
+
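The schema of the sendthat.cc search endpoint isn't documented in this commit; judging from how answer_question consumes the result further down, search_document is expected to return JSON roughly shaped like the sketch below (field names inferred from the code, values purely illustrative):

# Assumed response shape for POST {VECTOR_API_URL}/search/{HISTORY_INDEX} (illustrative only).
example_response = {
    "results": [
        {
            "metadata": {
                "content": "Retrieved passage text...",
                "title": "Some document title",           # falls back to 'Unknown Source'
                "source": "https://example.org/passage"   # falls back to 'No source provided'
            }
        }
    ]
}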
+ def generate_answer_cohere(question, context, citations):
+     headers = {
+         "Authorization": f"Bearer {MODELS['Cohere']['api_key']}",
+         "Content-Type": "application/json"
+     }
+
+     prompt = f"Context: {context}\n\nQuestion: {question}\n\nAnswer the question based on the given context. Include citations as [1], [2], etc.:"
+
+     payload = {
+         "message": prompt,
+         "model": MODELS['Cohere']['name'],
+         "preamble": "You are an AI-assistant chatbot. Provide thorough responses with citations.",
+         "chat_history": []
+     }
+
+     try:
+         response = requests.post(MODELS['Cohere']['api_url'], headers=headers, json=payload)
+         response.raise_for_status()
+         answer = response.json()['text']
+
+         answer += "\n\nSources:"
+         for i, citation in enumerate(citations, 1):
+             answer += f"\n[{i}] {citation}"
+
+         return answer
+     except requests.exceptions.RequestException as e:
+         logging.error(f"Error in generate_answer_cohere: {e}")
+         return f"An error occurred: {str(e)}"
+
+ def generate_answer_mistral(question, context, citations):
+     headers = {
+         "Authorization": f"Bearer {MODELS['Mistral']['api_key']}",
+         "Content-Type": "application/json",
+         "Accept": "application/json"
+     }
+
+     prompt = f"Context: {context}\n\nQuestion: {question}\n\nAnswer the question based on the given context. Include citations as [1], [2], etc.:"
+
+     payload = {
+         "model": MODELS['Mistral']['name'],
+         "messages": [
+             {
+                 "role": "user",
+                 "content": prompt
+             }
+         ]
+     }
+
+     try:
+         response = requests.post(MODELS['Mistral']['api_url'], headers=headers, json=payload)
+         response.raise_for_status()
+         answer = response.json()['choices'][0]['message']['content']
+
+         answer += "\n\nSources:"
+         for i, citation in enumerate(citations, 1):
+             answer += f"\n[{i}] {citation}"
+
+         return answer
+     except requests.exceptions.RequestException as e:
+         logging.error(f"Error in generate_answer_mistral: {e}")
+         return f"An error occurred: {str(e)}"
+
+ def generate_answer_gemini(question, context, citations):
+     prompt = f"Context: {context}\n\nQuestion: {question}\n\nAnswer the question based on the given context. Include citations as [1], [2], etc.:"
+
+     try:
+         model = MODELS['Gemini']['model']
+         response = model.generate_content(
+             prompt,
+             generation_config=genai.types.GenerationConfig(
+                 temperature=1.0,
+                 top_k=40,
+                 top_p=0.95,
+                 max_output_tokens=8192,
+             )
+         )
+
+         answer = response.text
+
+         answer += "\n\nSources:"
+         for i, citation in enumerate(citations, 1):
+             answer += f"\n[{i}] {citation}"
+
+         return answer
+     except Exception as e:
+         logging.error(f"Error in generate_answer_gemini: {e}")
+         return f"An error occurred: {str(e)}"
+
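All three generate_answer_* functions end with the same loop that appends a numbered Sources footer. A small helper like the one below (a possible refactor, not part of this commit) would keep the three copies in sync:

def append_citations(answer, citations):
    # Append a numbered "Sources" footer, matching the inline loops above.
    answer += "\n\nSources:"
    for i, citation in enumerate(citations, 1):
        answer += f"\n[{i}] {citation}"
    return answer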
+ def answer_question(question, model_choice, k=3):
+     # Search the vector database
+     search_results, _, _ = search_document(question, k)
+
+     # Extract and combine the retrieved contexts
+     if "results" in search_results:
+         contexts = []
+         citations = []
+         for item in search_results['results']:
+             contexts.append(item['metadata']['content'])
+             citations.append(f"{item['metadata'].get('title', 'Unknown Source')} - {item['metadata'].get('source', 'No source provided')}")
+         combined_context = " ".join(contexts)
+     else:
+         logging.error(f"Error in database search or no results found: {search_results}")
+         combined_context = ""
+         citations = []
+
+     # Generate answer using the selected model
+     if model_choice == "Cohere":
+         return generate_answer_cohere(question, combined_context, citations)
+     elif model_choice == "Mistral":
+         return generate_answer_mistral(question, combined_context, citations)
+     else:
+         return generate_answer_gemini(question, combined_context, citations)
+
+ def chatbot(message, history, model_choice):
+     response = answer_question(message, model_choice)
+     return response
+
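For debugging outside the Gradio UI, answer_question can be exercised directly from a Python shell; this assumes valid API keys in the environment and a reachable vector index:

# Illustrative smoke test, not part of app.py.
print(answer_question("Why was Anne Hutchinson banished from Massachusetts?", "Cohere", k=3))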
+ # Example questions with default model choice
+ EXAMPLE_QUESTIONS = [
+     ["Why was Anne Hutchinson banished from Massachusetts?", "Cohere"],
+     ["What were the major causes of World War I?", "Mistral"],
+     ["Who was the first President of the United States?", "Gemini"],
+     ["What was the significance of the Industrial Revolution?", "Cohere"]
+ ]
+
+ # Create Gradio interface
+ with gr.Blocks(theme="soft") as iface:
+     gr.Markdown("# History Chatbot")
+     gr.Markdown("Ask me anything about history, and I'll provide answers with citations!")
+
+     with gr.Row():
+         model_choice = gr.Radio(
+             choices=["Cohere", "Mistral", "Gemini"],
+             value="Cohere",
+             label="Choose LLM Model",
+             info="Select which AI model to use for generating responses"
+         )
+
+     chatbot_interface = gr.ChatInterface(
+         fn=lambda message, history, model: chatbot(message, history, model),
+         additional_inputs=[model_choice],
+         chatbot=gr.Chatbot(height=300),
+         textbox=gr.Textbox(placeholder="Ask a question about history...", container=False, scale=7),
+         examples=EXAMPLE_QUESTIONS,
+         cache_examples=False,
+         retry_btn=None,
+         undo_btn="Delete Previous",
+         clear_btn="Clear",
+     )
+
+ # Launch the app
+ if __name__ == "__main__":
+     iface.launch()