Shreyas094 committed
Update app.py

app.py CHANGED
@@ -29,6 +29,7 @@ import os
 from mistralai import Mistral
 from dotenv import load_dotenv
 import re
+from typing import List, Tuple
 
 # Automatically get the current year
 current_year = datetime.datetime.now().year

@@ -65,6 +66,94 @@ mistral_client = Mistral(api_key=MISTRAL_API_KEY)
 similarity_model = SentenceTransformer('all-MiniLM-L6-v2')
 
 
+def determine_query_type(query: str, chat_history: str, llm_client) -> str:
+    system_prompt = """You are an intelligent agent tasked with determining whether a user query requires a web search or can be answered using the AI's existing knowledge base. Your task is to analyze the query and decide on the appropriate action.
+
+Instructions:
+1. If the query is a general conversation starter, greeting, or can be answered without real-time information, classify it as "knowledge_base".
+2. If the query requires up-to-date information, news, or specific data that might change over time, classify it as "web_search".
+3. Consider the chat history when making your decision.
+4. Respond with ONLY "knowledge_base" or "web_search".
+
+Examples:
+- "Hi, how are you?" -> "knowledge_base"
+- "What's the latest news in the US?" -> "web_search"
+- "Can you explain quantum computing?" -> "knowledge_base"
+- "What are the current stock prices for Apple?" -> "web_search"
+"""
+
+    user_prompt = f"""
+Chat history:
+{chat_history}
+
+Current query: {query}
+
+Determine if this query requires a web search or can be answered from the knowledge base.
+"""
+
+    messages = [
+        {"role": "system", "content": system_prompt},
+        {"role": "user", "content": user_prompt}
+    ]
+
+    try:
+        response = llm_client.chat_completion(
+            messages=messages,
+            max_tokens=10,
+            temperature=0.2
+        )
+        decision = response.choices[0].message.content.strip().lower()
+        return "web_search" if decision == "web_search" else "knowledge_base"
+    except Exception as e:
+        logger.error(f"Error determining query type: {e}")
+        return "web_search"  # Default to web search if there's an error
+
+def generate_ai_response(query: str, chat_history: str, llm_client, model: str) -> str:
+    system_prompt = """You are a helpful AI assistant. Provide a concise and informative response to the user's query based on your existing knowledge. Do not make up information or claim to have real-time data."""
+
+    user_prompt = f"""
+Chat history:
+{chat_history}
+
+Current query: {query}
+
+Please provide a response to the query.
+"""
+
+    messages = [
+        {"role": "system", "content": system_prompt},
+        {"role": "user", "content": user_prompt}
+    ]
+
+    try:
+        if model == "groq":
+            response = groq_client.chat.completions.create(
+                messages=messages,
+                model="llama-3.1-70b-instant",
+                max_tokens=500,
+                temperature=0.7
+            )
+            return response.choices[0].message.content.strip()
+        elif model == "mistral":
+            response = mistral_client.chat.complete(
+                model="open-mistral-nemo",
+                messages=messages,
+                max_tokens=500,
+                temperature=0.7
+            )
+            return response.choices[0].message.content.strip()
+        else:  # huggingface
+            response = llm_client.chat_completion(
+                messages=messages,
+                max_tokens=500,
+                temperature=0.7
+            )
+            return response.choices[0].message.content.strip()
+    except Exception as e:
+        logger.error(f"Error generating AI response: {e}")
+        return "I apologize, but I'm having trouble generating a response at the moment. Please try again later."
+
+
 # Set up a session with retry mechanism
 def requests_retry_session(
     retries=0,

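The router's contract is deliberately narrow: any reply other than an exact "web_search" (after strip/lower) falls back to "knowledge_base", and any exception falls back to "web_search". A minimal sketch of that behavior with a stubbed client; FakeClient is hypothetical and not part of app.py, it only mimics the .chat_completion(...) response shape the function reads:

# Hypothetical stub, for illustration only: mimics the response shape
# determine_query_type reads (response.choices[0].message.content).
from types import SimpleNamespace

class FakeClient:
    def __init__(self, reply):
        self.reply = reply

    def chat_completion(self, messages, max_tokens, temperature):
        message = SimpleNamespace(content=self.reply)
        return SimpleNamespace(choices=[SimpleNamespace(message=message)])

assert determine_query_type("Hi, how are you?", "", FakeClient("knowledge_base")) == "knowledge_base"
assert determine_query_type("Latest news in the US?", "", FakeClient("web_search")) == "web_search"
# An off-script reply routes to the knowledge base, because only an exact
# "web_search" answer selects the web path:
assert determine_query_type("unsure", "", FakeClient("Maybe?")) == "knowledge_base"
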
@@ -609,24 +698,29 @@ def search_and_scrape(query, chat_history, num_results=5, max_chars=3000, time_r
         logger.error(f"Unexpected error in search_and_scrape: {e}")
         return f"An unexpected error occurred during the search and scrape process: {e}"
 
-def chat_function(message, history, num_results, max_chars, time_range, language, category, engines, safesearch, method, llm_temperature, model, use_pydf2):
+def chat_function(message: str, history: List[Tuple[str, str]], num_results: int, max_chars: int, time_range: str, language: str, category: str, engines: List[str], safesearch: int, method: str, llm_temperature: float, model: str, use_pydf2: bool):
     chat_history = "\n".join([f"{role}: {msg}" for role, msg in history])
 
-    response = search_and_scrape(
-        query=message,
-        chat_history=chat_history,
-        num_results=num_results,
-        max_chars=max_chars,
-        time_range=time_range,
-        language=language,
-        category=category,
-        engines=engines,
-        safesearch=safesearch,
-        method=method,
-        llm_temperature=llm_temperature,
-        model=model,
-        use_pydf2=use_pydf2
-    )
+    query_type = determine_query_type(message, chat_history, client)
+
+    if query_type == "knowledge_base":
+        response = generate_ai_response(message, chat_history, client, model)
+    else:  # web_search
+        response = search_and_scrape(
+            query=message,
+            chat_history=chat_history,
+            num_results=num_results,
+            max_chars=max_chars,
+            time_range=time_range,
+            language=language,
+            category=category,
+            engines=engines,
+            safesearch=safesearch,
+            method=method,
+            llm_temperature=llm_temperature,
+            model=model,
+            use_pydf2=use_pydf2
+        )
 
     yield response
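Since chat_function is a generator (it yields the response rather than returning it), a chat UI such as a Gradio ChatInterface can consume it directly. A hedged driver sketch follows; every argument value below is an illustrative assumption, not the app's actual defaults, and it assumes app.py's module-level client is already configured:

# Illustrative smoke test; parameter values are assumptions, not app defaults.
history = [("user", "Hi, how are you?"), ("assistant", "I'm doing well, thanks!")]

for response in chat_function(
    message="What's the latest news in the US?",
    history=history,
    num_results=5,
    max_chars=3000,
    time_range="",           # no time filter
    language="en",
    category="general",
    engines=["duckduckgo"],  # assumed engine name
    safesearch=1,
    method="GET",
    llm_temperature=0.2,
    model="mistral",
    use_pydf2=False,
):
    print(response)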