ambrosfitz committed on
Commit
20fac3a
·
verified ·
1 Parent(s): 67fe103

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +101 -42
app.py CHANGED
@@ -8,21 +8,36 @@ from dotenv import load_dotenv
8
  # Load environment variables
9
  load_dotenv()
10
 
11
- # Cohere API configuration
12
  COHERE_API_KEY = os.getenv("COHERE_API_KEY")
13
- if not COHERE_API_KEY:
14
- raise ValueError("COHERE_API_KEY not found in environment variables")
15
 
16
- COHERE_API_URL = "https://api.cohere.ai/v1/chat"
17
- MODEL_NAME = "command-r-08-2024"
18
 
19
- # Vector database configuration
20
- API_URL = "https://sendthat.cc"
 
 
21
  HISTORY_INDEX = "history"
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  def search_document(query, k):
24
  try:
25
- url = f"{API_URL}/search/{HISTORY_INDEX}"
26
  payload = {"text": query, "k": k}
27
  headers = {"Content-Type": "application/json"}
28
  response = requests.post(url, json=payload, headers=headers)
@@ -32,37 +47,69 @@ def search_document(query, k):
32
  logging.error(f"Error in search: {e}")
33
  return {"error": str(e)}, query, k
34
 
35
- def generate_answer(question, context, citations):
36
- prompt = f"Context: {context}\n\nQuestion: {question}\n\nAnswer the question based on the given context. At the end of your answer, provide citations for the sources you used, referencing them as [1], [2], etc.:"
37
-
38
  headers = {
39
- "Authorization": f"Bearer {COHERE_API_KEY}",
40
  "Content-Type": "application/json"
41
  }
42
 
 
 
43
  payload = {
44
  "message": prompt,
45
- "model": MODEL_NAME,
46
- "preamble": "You are an AI-assistant chatbot. You are trained to assist users by providing thorough and helpful responses to their queries based on the given context. Always include citations at the end of your answer.",
47
- "chat_history": [] # You can add chat history here if needed
48
  }
49
 
50
  try:
51
- response = requests.post(COHERE_API_URL, headers=headers, json=payload)
52
  response.raise_for_status()
53
  answer = response.json()['text']
54
 
55
- # Append citations to the answer
56
  answer += "\n\nSources:"
57
  for i, citation in enumerate(citations, 1):
58
  answer += f"\n[{i}] {citation}"
59
 
60
  return answer
61
  except requests.exceptions.RequestException as e:
62
- logging.error(f"Error in generate_answer: {e}")
63
  return f"An error occurred: {str(e)}"
64
 
65
- def answer_question(question, k=3):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  # Search the vector database
67
  search_results, _, _ = search_document(question, k)
68
 
@@ -79,33 +126,45 @@ def answer_question(question, k=3):
79
  combined_context = ""
80
  citations = []
81
 
82
- # Generate answer using the Cohere LLM
83
- answer = generate_answer(question, combined_context, citations)
84
- return answer
 
 
85
 
86
- def chatbot(message, history):
87
- response = answer_question(message)
88
  return response
89
 
90
  # Create Gradio interface
91
- iface = gr.ChatInterface(
92
- chatbot,
93
- chatbot=gr.Chatbot(height=300),
94
- textbox=gr.Textbox(placeholder="Ask a question about history...", container=False, scale=7),
95
- title="History Chatbot",
96
- description="Ask me anything about history, and I'll provide answers with citations!",
97
- theme="soft",
98
- examples=[
99
- "Why was Anne Hutchinson banished from Massachusetts?",
100
- "What were the major causes of World War I?",
101
- "Who was the first President of the United States?",
102
- "What was the significance of the Industrial Revolution?"
103
- ],
104
- cache_examples=False,
105
- retry_btn=None,
106
- undo_btn="Delete Previous",
107
- clear_btn="Clear",
108
- )
 
 
 
 
 
 
 
 
 
 
109
 
110
  # Launch the app
111
  if __name__ == "__main__":
 
8
  # Load environment variables
9
  load_dotenv()
10
 
11
+ # API Keys configuration
12
  COHERE_API_KEY = os.getenv("COHERE_API_KEY")
13
+ MISTRAL_API_KEY = os.getenv("MISTRAL_API_KEY")
 
14
 
15
+ if not COHERE_API_KEY or not MISTRAL_API_KEY:
16
+ raise ValueError("Missing required API keys in environment variables")
17
 
18
+ # API endpoints configuration
19
+ COHERE_API_URL = "https://api.cohere.ai/v1/chat"
20
+ MISTRAL_API_URL = "https://api.mistral.ai/v1/chat/completions"
21
+ VECTOR_API_URL = "https://sendthat.cc"
22
  HISTORY_INDEX = "history"
23
 
24
+ # Model configurations
25
+ MODELS = {
26
+ "Cohere": {
27
+ "name": "command-r-08-2024",
28
+ "api_url": COHERE_API_URL,
29
+ "api_key": COHERE_API_KEY
30
+ },
31
+ "Mistral": {
32
+ "name": "ft:open-mistral-nemo:ef730d29:20241022:2a0e7d46",
33
+ "api_url": MISTRAL_API_URL,
34
+ "api_key": MISTRAL_API_KEY
35
+ }
36
+ }
37
+
38
  def search_document(query, k):
39
  try:
40
+ url = f"{VECTOR_API_URL}/search/{HISTORY_INDEX}"
41
  payload = {"text": query, "k": k}
42
  headers = {"Content-Type": "application/json"}
43
  response = requests.post(url, json=payload, headers=headers)
 
47
  logging.error(f"Error in search: {e}")
48
  return {"error": str(e)}, query, k
49
 
50
def generate_answer_cohere(question, context, citations):
    """Generate an answer via Cohere's chat API and append numbered citations.

    Args:
        question: The user's question.
        context: Retrieved document text used to ground the answer.
        citations: Iterable of source strings rendered as "[1] ..." entries.

    Returns:
        The model's answer followed by a "Sources:" list, or an
        "An error occurred: ..." string if the request or response fails.
    """
    headers = {
        "Authorization": f"Bearer {MODELS['Cohere']['api_key']}",
        "Content-Type": "application/json"
    }

    prompt = f"Context: {context}\n\nQuestion: {question}\n\nAnswer the question based on the given context. Include citations as [1], [2], etc.:"

    payload = {
        "message": prompt,
        "model": MODELS['Cohere']['name'],
        "preamble": "You are an AI-assistant chatbot. Provide thorough responses with citations.",
        "chat_history": []
    }

    try:
        # timeout: requests has no default timeout, so a stalled API call
        # would otherwise hang the chat callback indefinitely.
        response = requests.post(
            MODELS['Cohere']['api_url'],
            headers=headers,
            json=payload,
            timeout=60,
        )
        response.raise_for_status()
        answer = response.json()['text']

        # Append the numbered source list so users can trace citations.
        answer += "\n\nSources:"
        for i, citation in enumerate(citations, 1):
            answer += f"\n[{i}] {citation}"

        return answer
    except requests.exceptions.RequestException as e:
        logging.error(f"Error in generate_answer_cohere: {e}")
        return f"An error occurred: {str(e)}"
    except (KeyError, ValueError) as e:
        # Malformed response body (invalid JSON or missing 'text') previously
        # escaped the handler and crashed the caller; report it like any error.
        logging.error(f"Unexpected response format in generate_answer_cohere: {e}")
        return f"An error occurred: {str(e)}"
78
 
79
def generate_answer_mistral(question, context, citations):
    """Generate an answer via Mistral's chat-completions API and append citations.

    Args:
        question: The user's question.
        context: Retrieved document text used to ground the answer.
        citations: Iterable of source strings rendered as "[1] ..." entries.

    Returns:
        The model's answer followed by a "Sources:" list, or an
        "An error occurred: ..." string if the request or response fails.
    """
    headers = {
        "Authorization": f"Bearer {MODELS['Mistral']['api_key']}",
        "Content-Type": "application/json",
        "Accept": "application/json"
    }

    prompt = f"Context: {context}\n\nQuestion: {question}\n\nAnswer the question based on the given context. Include citations as [1], [2], etc.:"

    payload = {
        "model": MODELS['Mistral']['name'],
        "messages": [
            {
                "role": "user",
                "content": prompt
            }
        ]
    }

    try:
        # timeout: requests has no default timeout, so a stalled API call
        # would otherwise hang the chat callback indefinitely.
        response = requests.post(
            MODELS['Mistral']['api_url'],
            headers=headers,
            json=payload,
            timeout=60,
        )
        response.raise_for_status()
        answer = response.json()['choices'][0]['message']['content']

        # Append the numbered source list so users can trace citations.
        answer += "\n\nSources:"
        for i, citation in enumerate(citations, 1):
            answer += f"\n[{i}] {citation}"

        return answer
    except requests.exceptions.RequestException as e:
        logging.error(f"Error in generate_answer_mistral: {e}")
        return f"An error occurred: {str(e)}"
    except (KeyError, IndexError, ValueError) as e:
        # Malformed response body (invalid JSON, empty 'choices', or missing
        # keys) previously escaped the handler; report it like any error.
        logging.error(f"Unexpected response format in generate_answer_mistral: {e}")
        return f"An error occurred: {str(e)}"
111
+
112
+ def answer_question(question, model_choice, k=3):
113
  # Search the vector database
114
  search_results, _, _ = search_document(question, k)
115
 
 
126
  combined_context = ""
127
  citations = []
128
 
129
+ # Generate answer using the selected model
130
+ if model_choice == "Cohere":
131
+ return generate_answer_cohere(question, combined_context, citations)
132
+ else:
133
+ return generate_answer_mistral(question, combined_context, citations)
134
 
135
def chatbot(message, history, model_choice):
    """Gradio chat callback: answer *message* with the model named by *model_choice*.

    The conversation *history* is managed by the Gradio ChatInterface itself
    and is not forwarded; only the current message is answered.
    """
    return answer_question(message, model_choice)
138
 
139
  # Create Gradio interface
140
+ with gr.Blocks(theme="soft") as iface:
141
+ gr.Markdown("# History Chatbot")
142
+ gr.Markdown("Ask me anything about history, and I'll provide answers with citations!")
143
+
144
+ with gr.Row():
145
+ model_choice = gr.Radio(
146
+ choices=["Cohere", "Mistral"],
147
+ value="Cohere",
148
+ label="Choose LLM Model",
149
+ info="Select which AI model to use for generating responses"
150
+ )
151
+
152
+ chatbot_interface = gr.ChatInterface(
153
+ fn=lambda message, history, model: chatbot(message, history, model),
154
+ additional_inputs=[model_choice],
155
+ chatbot=gr.Chatbot(height=300),
156
+ textbox=gr.Textbox(placeholder="Ask a question about history...", container=False, scale=7),
157
+ examples=[
158
+ "Why was Anne Hutchinson banished from Massachusetts?",
159
+ "What were the major causes of World War I?",
160
+ "Who was the first President of the United States?",
161
+ "What was the significance of the Industrial Revolution?"
162
+ ],
163
+ cache_examples=False,
164
+ retry_btn=None,
165
+ undo_btn="Delete Previous",
166
+ clear_btn="Clear",
167
+ )
168
 
169
  # Launch the app
170
  if __name__ == "__main__":