File size: 12,129 Bytes
5092552
ca98093
 
5092552
 
d4557ee
 
 
e292008
ca98093
 
1fa6961
ca98093
 
0c69489
ca98093
 
e292008
ca98093
 
 
 
 
 
 
 
f4505e9
5092552
 
ca98093
 
a35ea13
ca98093
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f4505e9
86c8869
ca98093
 
 
 
 
f0e66e7
f4505e9
 
86c8869
ca98093
 
 
 
 
f0e66e7
1fa6961
f4505e9
86c8869
ca98093
 
 
 
 
f0e66e7
cc467c2
f4505e9
86c8869
ca98093
 
 
 
 
86c8869
f4505e9
 
 
 
86c8869
ca98093
 
 
 
 
f0e66e7
a55679f
f4505e9
ca98093
 
 
 
 
25c1140
ca98093
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86c8869
ca98093
 
 
 
 
 
 
 
 
25c1140
d4557ee
0f81d99
f4505e9
ca98093
 
 
 
 
25c1140
ca98093
 
 
 
 
 
 
 
 
 
25c1140
ca98093
b1b6e20
ca98093
 
 
 
 
f4505e9
ca98093
 
 
 
 
 
a55679f
ca98093
 
 
 
 
 
 
 
 
 
 
 
 
 
e292008
ca98093
 
 
 
 
e292008
ca98093
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0c69489
ca98093
 
 
 
 
 
 
 
5092552
ca98093
 
 
5092552
ca98093
 
 
 
 
 
 
0c69489
ca98093
 
 
 
 
 
 
 
5092552
ca98093
 
 
 
 
 
 
 
 
 
 
 
 
a55679f
ca98093
 
 
 
 
 
 
 
 
5092552
ca98093
 
 
 
 
5092552
ca98093
5092552
ca98093
 
5092552
ca98093
0f81d99
ca98093
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72c7dbb
ca98093
 
 
5092552
ca98093
 
0c69489
ca98093
 
 
 
5092552
 
ca98093
5092552
ca98093
5092552
 
86c8869
ca98093
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
"""
Enhanced LangGraph Agent with Multi-LLM Support and Proper Question Answering
Combines your original LangGraph structure with enhanced response handling
"""

import os
import time
import random
from dotenv import load_dotenv
from typing import List, Dict, Any, TypedDict, Annotated
import operator

from langgraph.graph import START, StateGraph, MessagesState, END
from langgraph.prebuilt import tools_condition, ToolNode
from langgraph.checkpoint.memory import MemorySaver

from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_groq import ChatGroq
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
from langchain_community.vectorstores import SupabaseVectorStore
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
from langchain_core.tools import tool
from langchain.tools.retriever import create_retriever_tool
from supabase.client import Client, create_client

# Load variables from a local .env file (python-dotenv) into the process
# environment, so the os.getenv("SUPABASE_URL") / os.getenv("SUPABASE_SERVICE_KEY")
# checks in build_graph can see them. Presumably this also supplies the API keys
# that the chat-model and search clients read — those reads are not visible here.
load_dotenv()

# Enhanced system prompt for better question answering.
# build_graph's retriever node sends this text as a SystemMessage, with a
# "Similar example for reference" section appended when the vector store
# returns a match. The "FINAL ANSWER:" prefix requested below is the same
# marker that build_graph's format_final_answer node checks for and prepends
# when it is missing.
ENHANCED_SYSTEM_PROMPT = """You are a helpful assistant tasked with answering questions using a set of tools.

CRITICAL INSTRUCTIONS:
1. Read the question carefully and understand what specific information is being asked
2. Use the appropriate tools to find the exact information requested
3. For factual questions, search for current and accurate information
4. For calculations, use the math tools provided
5. Always provide specific, direct answers - never repeat the question as your answer
6. If you cannot find the information, state "Information not available" 
7. Format your final response as: FINAL ANSWER: [your specific answer]

ANSWER FORMAT RULES:
- For numbers: provide just the number without commas or units unless specified
- For names/strings: provide the exact name or term without articles
- For lists: provide comma-separated values
- Be concise and specific in your final answer

Remember: Your job is to ANSWER the question, not repeat it back."""

# ---- Enhanced Tool Definitions ----
@tool
def multiply(a: int, b: int) -> int:
    """Multiply two numbers.
    Args:
        a: first int
        b: second int
    """
    # The @tool decorator derives the tool description shown to the LLM from
    # the docstring above, so it is kept verbatim.
    return operator.mul(a, b)

@tool
def add(a: int, b: int) -> int:
    """Add two numbers.
    Args:
        a: first int
        b: second int
    """
    # The @tool decorator derives the tool description shown to the LLM from
    # the docstring above, so it is kept verbatim.
    return operator.add(a, b)

@tool
def subtract(a: int, b: int) -> int:
    """Subtract two numbers.
    Args:
        a: first int
        b: second int
    """
    # The @tool decorator derives the tool description shown to the LLM from
    # the docstring above, so it is kept verbatim. Result is a minus b.
    return operator.sub(a, b)

@tool
def divide(a: int, b: int) -> float:
    """Divide two numbers.
    Args:
        a: first int
        b: second int
    """
    # The @tool decorator derives the tool description shown to the LLM from
    # the docstring above, so it is kept verbatim.
    # A zero divisor is rejected with a readable ValueError instead of
    # letting Python raise ZeroDivisionError.
    if b == 0:
        raise ValueError("Cannot divide by zero.")
    # True division: the result is a float even for two ints.
    return operator.truediv(a, b)

@tool
def modulus(a: int, b: int) -> int:
    """Get the modulus of two numbers.
    Args:
        a: first int
        b: second int
    """
    # The @tool decorator derives the tool description shown to the LLM from
    # the docstring above, so it is kept verbatim.
    # NOTE(review): unlike divide, b == 0 is not guarded here, so Python's
    # ZeroDivisionError propagates to the caller.
    return operator.mod(a, b)

@tool
def wiki_search(query: str) -> str:
    """Search Wikipedia for a query and return maximum 2 results.
    Args:
        query: The search query.
    """
    # The @tool decorator derives the tool description shown to the LLM from
    # the docstring above, so it is kept verbatim. Failures are reported as a
    # string result rather than raised, so the agent can read them.
    try:
        time.sleep(random.uniform(0.5, 1.0))  # Rate limiting
        pages = WikipediaLoader(query=query, load_max_docs=2).load()
        if not pages:
            return "No Wikipedia results found"

        # Wrap each page (truncated to 1500 chars) in a <Document> tag.
        rendered = []
        for page in pages:
            meta = page.metadata
            rendered.append(
                f'<Document source="{meta.get("source", "Wikipedia")}" title="{meta.get("title", "")}">\n'
                f'{page.page_content[:1500]}\n'
                '</Document>'
            )
        return "\n\n---\n\n".join(rendered)
    except Exception as e:
        return f"Wikipedia search failed: {e}"

@tool
def web_search(query: str) -> str:
    """Search Tavily for a query and return maximum 3 results.
    Args:
        query: The search query.
    """
    # The @tool decorator derives the tool description shown to the LLM from
    # the docstring above, so it is kept verbatim. Failures are reported as a
    # string result rather than raised, so the agent can read them.
    try:
        time.sleep(random.uniform(0.7, 1.2))  # Rate limiting
        search_tool = TavilySearchResults(max_results=3)
        search_docs = search_tool.invoke({"query": query})
        if not search_docs:
            return "No web search results found"
        # On an API error the Tavily tool returns an error *string* instead of
        # a list of result dicts. Iterating that string would fail with an
        # unhelpful "'str' object has no attribute 'get'", so surface the
        # actual error text instead.
        if isinstance(search_docs, str):
            return f"Web search failed: {search_docs}"

        # Wrap each result (content truncated to 1200 chars) in a <Document>
        # tag; skip anything that is not a result dict.
        formatted_search_docs = "\n\n---\n\n".join(
            f'<Document source="{doc.get("url", "")}">\n{doc.get("content", "")[:1200]}\n</Document>'
            for doc in search_docs
            if isinstance(doc, dict)
        )
        return formatted_search_docs or "No web search results found"
    except Exception as e:
        return f"Web search failed: {e}"

@tool
def arxiv_search(query: str) -> str:
    """Search Arxiv for a query and return maximum 3 results.
    Args:
        query: The search query.
    """
    # The @tool decorator derives the tool description shown to the LLM from
    # the docstring above, so it is kept verbatim. Failures are reported as a
    # string result rather than raised, so the agent can read them.
    try:
        time.sleep(random.uniform(0.5, 1.0))  # Rate limiting
        papers = ArxivLoader(query=query, load_max_docs=3).load()
        if not papers:
            return "No ArXiv results found"

        # Wrap each paper (truncated to 1000 chars) in a <Document> tag.
        rendered = []
        for paper in papers:
            source = paper.metadata.get("source", "ArXiv")
            title = paper.metadata.get("title", "")
            excerpt = paper.page_content[:1000]
            rendered.append(f'<Document source="{source}" title="{title}">\n{excerpt}\n</Document>')
        return "\n\n---\n\n".join(rendered)
    except Exception as e:
        return f"ArXiv search failed: {e}"

# Initialize tools list
tools = [
    # Arithmetic helpers
    multiply,
    add,
    subtract,
    divide,
    modulus,
    # Search / retrieval helpers
    wiki_search,
    web_search,
    arxiv_search,
]

# Enhanced State for better tracking
class EnhancedState(MessagesState):
    """Enhanced state with additional tracking.

    MessagesState is a TypedDict, so this is a TypedDict too, and TypedDict
    fields cannot carry default values. The previous ``= ""`` / ``= []``
    assignments were never applied to state dicts; they only became class
    attributes, and the ``[]`` was a single list shared by every user.

    NOTE(review): build_graph builds its graph on MessagesState, not on this
    class, so these extra keys are currently unused.
    """

    query: str  # presumably the user's question — TODO confirm intended use
    tools_used: List[str]  # presumably names of tools invoked — TODO confirm
    search_results: str

def _content_text(content: Any) -> str:
    """Return message ``content`` as plain text.

    LangChain message content may be a string or a list of parts (strings, or
    dicts carrying a ``"text"`` key). The list form breaks ``str`` methods
    such as ``.strip()``, so it is flattened here.
    """
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        parts = []
        for part in content:
            if isinstance(part, str):
                parts.append(part)
            elif isinstance(part, dict) and isinstance(part.get("text"), str):
                parts.append(part["text"])
        return "".join(parts)
    return "" if content is None else str(content)


def _last_human_text(messages: List[Any]) -> str:
    """Return the text of the most recent HumanMessage in *messages*, or ""."""
    for message in reversed(messages):
        if isinstance(message, HumanMessage):
            return _content_text(message.content)
    return ""


def _prompt_messages(messages: List[Any]) -> List[Any]:
    """Order state messages for the LLM: latest system prompt first, then the rest.

    The ``add_messages`` reducer behind MessagesState appends messages that
    have new ids, so the retriever's SystemMessage lands *after* the user's
    question. This moves it to the front. System prompts from earlier turns on
    the same thread are left out of the prompt (they stay in state).
    """
    system = [m for m in messages if isinstance(m, SystemMessage)]
    conversation = [m for m in messages if not isinstance(m, SystemMessage)]
    return system[-1:] + conversation


def _create_llm(provider: str):
    """Return the chat model for *provider*.

    Raises:
        ValueError: if *provider* is not "google", "groq" or "huggingface".
    """
    if provider == "google":
        return ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
    if provider == "groq":
        return ChatGroq(model="llama3-70b-8192", temperature=0)  # Using more reliable model
    if provider == "huggingface":
        return ChatHuggingFace(
            llm=HuggingFaceEndpoint(
                # HuggingFaceEndpoint takes the URL via `endpoint_url`; a bare
                # `url=` keyword is not the endpoint field.
                # NOTE(review): confirm this model path exists on the Hub.
                endpoint_url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
                temperature=0,
            ),
        )
    raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")


def _create_vector_store():
    """Return a Supabase vector store, or None if unconfigured or init fails.

    Best-effort: initialization errors are printed, not raised.
    """
    url = os.getenv("SUPABASE_URL")
    key = os.getenv("SUPABASE_SERVICE_KEY")
    if not (url and key):
        return None
    try:
        embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
        supabase: Client = create_client(url, key)
        return SupabaseVectorStore(
            client=supabase,
            embedding=embeddings,
            table_name="documents",
            query_name="match_documents_langchain",
        )
    except Exception as e:
        print(f"Vector store initialization failed: {e}")
        return None


def build_graph(provider: str = "groq"):
    """Build and compile the question-answering agent graph.

    Flow: START -> retriever -> assistant <-> tools. When the assistant stops
    requesting tools: -> formatter -> END. The graph is compiled with an
    in-memory checkpointer (MemorySaver); callers select the conversation via
    ``config["configurable"]["thread_id"]``.

    Args:
        provider: chat model backing the assistant: "google", "groq" or
            "huggingface".

    Returns:
        The compiled LangGraph graph.

    Raises:
        ValueError: if *provider* is not one of the supported names.
    """
    llm = _create_llm(provider)
    llm_with_tools = llm.bind_tools(tools)
    vector_store = _create_vector_store()

    def retriever(state: MessagesState):
        """Emit the system prompt, plus a similar example from the vector store if any."""
        query = _last_human_text(state["messages"])

        similar_context = ""
        if vector_store is not None and query:
            try:
                similar_questions = vector_store.similarity_search(query, k=1)
                if similar_questions:
                    similar_context = f"\n\nSimilar example for reference:\n{similar_questions[0].page_content}"
            except Exception as e:
                print(f"Vector search failed: {e}")

        sys_msg = SystemMessage(content=ENHANCED_SYSTEM_PROMPT + similar_context)
        # add_messages appends this after the question (re-sending existing
        # messages would not move it); the assistant reorders the prompt.
        return {"messages": [sys_msg]}

    def assistant(state: MessagesState):
        """Call the tool-bound LLM; re-prompt once if it just echoes the question."""
        try:
            prompt = _prompt_messages(state["messages"])
            response = llm_with_tools.invoke(prompt)

            # Compare against the user's question. state["messages"][-1] is the
            # system prompt or a ToolMessage at this point, not the question.
            original_query = _last_human_text(state["messages"]).strip()
            answer = _content_text(getattr(response, "content", "")).strip()
            if original_query and answer == original_query:
                retry_prompt = prompt + [
                    HumanMessage(content=f"Please provide a specific answer to this question, do not repeat the question: {original_query}")
                ]
                response = llm_with_tools.invoke(retry_prompt)

            return {"messages": [response]}
        except Exception as e:
            error_response = AIMessage(content=f"Error processing request: {e}")
            return {"messages": [error_response]}

    def format_final_answer(state: MessagesState):
        """Ensure the last message is prefixed with "FINAL ANSWER:"."""
        messages = state["messages"]
        if not messages:
            return {"messages": [AIMessage(content="FINAL ANSWER: Information not available")]}

        last_message = messages[-1]
        content = _content_text(getattr(last_message, "content", "")).strip()
        if "FINAL ANSWER:" in content:
            return {"messages": []}  # already formatted; nothing to update

        formatted_content = (
            f"FINAL ANSWER: {content}" if content else "FINAL ANSWER: Information not available"
        )
        # Reuse the last message's id so add_messages replaces it in place
        # instead of appending a second copy of the answer.
        formatted_message = AIMessage(content=formatted_content, id=getattr(last_message, "id", None))
        return {"messages": [formatted_message]}

    # Build the graph
    builder = StateGraph(MessagesState)

    builder.add_node("retriever", retriever)
    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(tools))
    builder.add_node("formatter", format_final_answer)

    builder.add_edge(START, "retriever")
    builder.add_edge("retriever", "assistant")
    # tools_condition routes to "tools" when the last AI message requests tool
    # calls, otherwise to "__end__", which is redirected to the formatter.
    builder.add_conditional_edges(
        "assistant",
        tools_condition,
        {
            "tools": "tools",
            "__end__": "formatter",
        },
    )
    builder.add_edge("tools", "assistant")
    builder.add_edge("formatter", END)

    return builder.compile(checkpointer=MemorySaver())

# Test function
def test_agent():
    """Smoke-test the groq-backed graph on a few sample questions, printing each answer.

    Each question runs on its own checkpointer thread, keyed by the question's hash.
    Errors are printed per question rather than raised.
    """
    graph = build_graph(provider="groq")

    sample_questions = (
        "How many studio albums were published by Mercedes Sosa between 2000 and 2009?",
        "What is 25 multiplied by 17?",
        "Who nominated the only Featured Article on English Wikipedia about a dinosaur that was promoted in November 2004?",
    )

    for question in sample_questions:
        print(f"\nQuestion: {question}")
        print("-" * 60)

        try:
            run_config = {"configurable": {"thread_id": f"test_{hash(question)}"}}
            result = graph.invoke({"messages": [HumanMessage(content=question)]}, run_config)

            if not (result and "messages" in result):
                print("Answer: No response generated")
            else:
                last = result["messages"][-1]
                answer = last.content if hasattr(last, 'content') else last
                print(f"Answer: {answer}")
        except Exception as e:
            print(f"Error: {e}")

        print()

if __name__ == "__main__":
    # Executing the module directly runs the sample-question smoke test.
    test_agent()