import os

import gradio as gr
from dotenv import load_dotenv
from langgraph.graph import START, StateGraph, MessagesState
from langgraph.prebuilt import tools_condition, ToolNode
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_groq import ChatGroq
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
from langchain_community.vectorstores import SupabaseVectorStore
from langchain_core.messages import SystemMessage, HumanMessage
from langchain_core.tools import tool
from supabase import create_client, Client

# Load environment variables from .env (SUPABASE_URL, SUPABASE_SERVICE_KEY,
# TAVILY_API_KEY, and the API key for whichever LLM provider is selected).
load_dotenv()


# Tool definitions.
# Note: @tool requires a docstring; it becomes the tool description the LLM sees.

@tool
def multiply(a: int, b: int) -> int:
    """Multiply two integers."""
    return a * b


@tool
def add(a: int, b: int) -> int:
    """Add two integers."""
    return a + b


@tool
def subtract(a: int, b: int) -> int:
    """Subtract b from a."""
    return a - b


@tool
def divide(a: int, b: int) -> float:
    """Divide a by b."""
    if b == 0:
        raise ValueError("Cannot divide by zero.")
    return a / b


@tool
def modulus(a: int, b: int) -> int:
    """Return a modulo b."""
    return a % b


@tool
def wiki_search(query: str) -> str:
    """Search Wikipedia and return up to 2 matching documents."""
    search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
    return "\n\n---\n\n".join(doc.page_content for doc in search_docs)


@tool
def web_search(query: str) -> str:
    """Search the web via Tavily and return up to 3 results."""
    # TavilySearchResults returns a list of dicts with "url" and "content" keys,
    # not Document objects.
    results = TavilySearchResults(max_results=3).invoke(query)
    return "\n\n---\n\n".join(r["content"] for r in results)


@tool
def arxiv_search(query: str) -> str:
    """Search arXiv and return up to 3 documents, truncated to 1000 characters each."""
    search_docs = ArxivLoader(query=query, load_max_docs=3).load()
    return "\n\n---\n\n".join(doc.page_content[:1000] for doc in search_docs)


# System prompt definition
SYSTEM_PROMPT = """You are a helpful assistant. For every question, reply with only
the answer—no explanation, no units, and no extra words. If the answer is a number,
just return the number. If it is a word or phrase, return only that. If it is a list,
return a comma-separated list with no extra words. Do not include any prefix, suffix,
or explanation."""
sys_msg = SystemMessage(content=SYSTEM_PROMPT)

# Initialize vector store (Supabase-backed, with Hugging Face embeddings)
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
supabase: Client = create_client(
    os.environ["SUPABASE_URL"],
    os.environ["SUPABASE_SERVICE_KEY"],
)
vector_store = SupabaseVectorStore(
    client=supabase,
    embedding=embeddings,
    table_name="documents",
)

tools = [multiply, add, subtract, divide, modulus, wiki_search, web_search, arxiv_search]


# Build graph function with multi-provider support
def build_graph(provider: str = "groq"):
    # Provider selection
    if provider == "google":
        llm = ChatGoogleGenerativeAI(
            model="gemini-2.0-flash",
            temperature=0,
            api_key=os.getenv("GOOGLE_API_KEY"),
        )
    elif provider == "groq":
        llm = ChatGroq(
            model="llama3-70b-8192",
            temperature=0,
            api_key=os.getenv("GROQ_API_KEY"),
        )
    elif provider == "huggingface":
        llm = ChatHuggingFace(
            llm=HuggingFaceEndpoint(
                endpoint_url="https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.2",
                temperature=0,
                huggingfacehub_api_token=os.getenv("HF_API_KEY"),
            )
        )
    else:
        raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")

    llm_with_tools = llm.bind_tools(tools)

    # Graph nodes
    def retriever(state: MessagesState):
        # Inject the system prompt and, when the vector store holds a similar
        # reference question, attach it as an extra human message.
        similar = vector_store.similarity_search(state["messages"][-1].content, k=1)
        messages = [sys_msg] + state["messages"]
        if similar:
            messages.append(
                HumanMessage(content=f"Similar reference: {similar[0].page_content[:200]}...")
            )
        return {"messages": messages}

    def assistant(state: MessagesState):
        return {"messages": [llm_with_tools.invoke(state["messages"])]}

    # Build graph
    builder = StateGraph(MessagesState)
    builder.add_node("retriever", retriever)
    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(tools))
    builder.add_edge(START, "retriever")
    builder.add_edge("retriever", "assistant")
    builder.add_conditional_edges("assistant", tools_condition)
    builder.add_edge("tools", "assistant")
    return builder.compile()


# Gradio interface
def run_agent(question, provider):
    try:
        graph = build_graph(provider)
        result = graph.invoke({"messages": [HumanMessage(content=question)]})
        return result["messages"][-1].content
    except Exception as e:
        return f"Error: {e}"


# Create Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("## LangGraph Multi-Provider Agent")
    provider = gr.Dropdown(
        choices=["groq", "google", "huggingface"],
        value="groq",
        label="LLM Provider",
    )
    question = gr.Textbox(label="Your Question")
    submit_btn = gr.Button("Run Agent")
    output = gr.Textbox(label="Agent Response", interactive=False)
    submit_btn.click(fn=run_agent, inputs=[question, provider], outputs=output)

if __name__ == "__main__":
    demo.launch()
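
# Quick sanity check (a minimal sketch, assuming this file is saved as app.py;
# run it from a separate script or REPL so it does not collide with the
# blocking demo.launch() call above):
#
#   from app import run_agent
#   print(run_agent("What is 6 * 7?", "groq"))   # the system prompt should yield just "42"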