File size: 5,957 Bytes
d4557ee
 
 
f4505e9
e292008
 
1fa6961
f4505e9
 
0c69489
e292008
0c69489
e292008
 
f4505e9
e292008
a35ea13
f4505e9
f0e66e7
 
 
f4505e9
 
f0e66e7
 
 
1fa6961
f4505e9
f0e66e7
 
 
cc467c2
f4505e9
 
f0e66e7
f4505e9
 
 
 
 
f0e66e7
 
 
a55679f
f4505e9
 
f0e66e7
25c1140
e292008
 
a55679f
 
 
 
25c1140
d4557ee
0f81d99
f4505e9
9a3d597
f0e66e7
25c1140
e292008
d4557ee
a55679f
e292008
a55679f
 
25c1140
d4557ee
b1b6e20
a55679f
e292008
f4505e9
 
 
e292008
d4557ee
f4505e9
0c69489
e292008
 
 
 
 
a55679f
 
 
 
d4557ee
a55679f
e292008
 
 
 
 
 
 
d4557ee
e292008
 
 
a55679f
 
 
e292008
 
 
a35ea13
e292008
 
a55679f
 
e292008
0c69489
e292008
 
 
 
 
0c69489
e292008
0c69489
e292008
 
 
 
 
0c69489
e292008
0c69489
e292008
 
 
 
 
a55679f
0c69489
 
e292008
 
0c69489
0ab2059
e292008
 
 
 
 
d4557ee
0f81d99
0c69489
 
 
 
 
 
 
 
72c7dbb
0c69489
e292008
 
a35ea13
e292008
 
0c69489
 
e292008
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
import hashlib
import operator
import os
import random
import time
from typing import Annotated, Any, Dict, List, TypedDict

from dotenv import load_dotenv

from langchain_core.tools import tool
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.document_loaders import WikipediaLoader
from langgraph.graph import StateGraph, END
from langgraph.checkpoint.memory import MemorySaver
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
from langchain_groq import ChatGroq

load_dotenv()  # load .env into the environment; GROQ_API_KEY is read later by _llm()

@tool
def multiply(a: int, b: int) -> int:
    """Return the product of two integers."""
    product = a * b
    return product

@tool
def add(a: int, b: int) -> int:
    """Return the sum of two integers."""
    total = b + a
    return total

@tool
def subtract(a: int, b: int) -> int:
    """Return the difference ``a - b``."""
    difference = a - b
    return difference

@tool
def divide(a: int, b: int) -> float:
    """Return ``a`` divided by ``b`` as a float.

    Raises:
        ValueError: if ``b`` is zero.
    """
    if not b:
        raise ValueError("Cannot divide by zero.")
    return a / b

@tool
def modulus(a: int, b: int) -> int:
    """Return the remainder of ``a`` divided by ``b``."""
    remainder = a % b
    return remainder

@tool
def optimized_web_search(query: str) -> str:
    """Search the web via Tavily and return up to two results as text.

    Each result is wrapped in a ``<Doc url='...'>`` tag with its content
    truncated to 500 characters; results are joined with ``---`` separators.
    Never raises: on any failure an error string is returned instead, so the
    agent loop can continue.
    """
    try:
        # Jittered pause to avoid hitting the search API rate limit in bursts.
        time.sleep(random.uniform(0.7, 1.5))
        # BUG FIX: BaseTool.invoke takes the tool input as its first positional
        # argument (a string or a dict of args), not a ``query=`` keyword —
        # the original ``invoke(query=query)`` raised TypeError before any
        # search ran.
        docs = TavilySearchResults(max_results=2).invoke({"query": query})
        return "\n\n---\n\n".join(
            f"<Doc url='{d.get('url','')}'>{d.get('content','')[:500]}</Doc>"
            for d in docs
        )
    except Exception as e:
        return f"Web search failed: {e}"

@tool
def optimized_wiki_search(query: str) -> str:
    """Look up *query* on Wikipedia and return the top article as text.

    The article is wrapped in a ``<Doc src='...'>`` tag with its content
    truncated to 800 characters. Never raises: on any failure an error string
    is returned instead, so the agent loop can continue.
    """
    try:
        # Brief jittered pause to avoid hammering the Wikipedia API.
        time.sleep(random.uniform(0.3, 1))
        loaded = WikipediaLoader(query=query, load_max_docs=1).load()
        snippets = [
            f"<Doc src='{doc.metadata.get('source','Wikipedia')}'>{doc.page_content[:800]}</Doc>"
            for doc in loaded
        ]
        return "\n\n---\n\n".join(snippets)
    except Exception as e:
        return f"Wikipedia search failed: {e}"

class EnhancedAgentState(TypedDict):
    """State dict carried through the LangGraph nodes."""
    # Conversation history; operator.add tells LangGraph to concatenate
    # message lists from node updates instead of overwriting them.
    messages: Annotated[List[HumanMessage | AIMessage], operator.add]
    # Raw user query text.
    query: str
    # Set by the router node: "llama8", "llama70", or "deepseek".
    agent_type: str
    # Answer produced by the selected model node.
    final_answer: str
    # Timing/provenance metadata written by the model nodes.
    perf: Dict[str, Any]
    # presumably a response slot for an Agno agent; unused in the visible code — TODO confirm
    agno_resp: str

class HybridLangGraphMultiLLMSystem:
    """
    Router that picks between Groq-hosted Llama-3 8B, Llama-3 70B (default),
    and Groq-hosted DeepSeek-Chat according to the query content.
    """

    def __init__(self):
        # Tools are collected for future bind_tools() use; the current graph
        # nodes call the LLMs directly and do not execute tools.
        self.tools = [
            multiply, add, subtract, divide, modulus,
            optimized_web_search, optimized_wiki_search
        ]
        self.graph = self._build_graph()

    def _llm(self, model_name: str):
        """Build a deterministic (temperature=0) Groq chat model."""
        return ChatGroq(
            model=model_name,
            temperature=0,
            api_key=os.getenv("GROQ_API_KEY")
        )

    def _build_graph(self):
        """Compile the graph: router -> one of three model nodes -> END."""
        # route key -> (provenance label, model). NOTE(review): "deepseek-chat"
        # may not be a valid Groq model id — confirm against Groq's model list.
        models = {
            "llama8": ("Groq-Llama3-8B", self._llm("llama3-8b-8192")),
            "llama70": ("Groq-Llama3-70B", self._llm("llama3-70b-8192")),
            "deepseek": ("Groq-DeepSeek", self._llm("deepseek-chat")),
        }

        def router(st: EnhancedAgentState) -> EnhancedAgentState:
            # Keyword routing on the query text; Llama-3 70B is the default.
            q = st["query"].lower()
            if "llama-8" in q:
                t = "llama8"
            elif "deepseek" in q:
                t = "deepseek"
            else:
                t = "llama70"
            return {**st, "agent_type": t}

        def _make_node(provider: str, llm):
            # Build one answer node for *llm*; factored out because the
            # original repeated this body verbatim for all three models.
            def node(st: EnhancedAgentState) -> EnhancedAgentState:
                t0 = time.time()
                sys = SystemMessage(content="You are a helpful AI assistant.")
                res = llm.invoke([sys, HumanMessage(content=st["query"])])
                return {**st,
                        "final_answer": res.content,
                        "perf": {"time": time.time() - t0, "prov": provider}}
            return node

        g = StateGraph(EnhancedAgentState)
        g.add_node("router", router)
        for key, (provider, llm) in models.items():
            g.add_node(key, _make_node(provider, llm))
            g.add_edge(key, END)
        g.set_entry_point("router")
        g.add_conditional_edges("router", lambda s: s["agent_type"],
                                {k: k for k in models})
        return g.compile(checkpointer=MemorySaver())

    def process_query(self, q: str) -> str:
        """Run *q* through the graph and return the stripped final answer."""
        state = {
            "messages": [HumanMessage(content=q)],
            "query": q,
            "agent_type": "",
            "final_answer": "",
            "perf": {},
            "agno_resp": ""
        }
        # FIX: built-in hash() is salted per process (PYTHONHASHSEED), so the
        # original thread ids were not stable across runs, defeating the
        # checkpointer. A content digest gives a reproducible thread id.
        digest = hashlib.md5(q.encode("utf-8")).hexdigest()[:12]
        cfg = {"configurable": {"thread_id": f"hyb_{digest}"}}
        out = self.graph.invoke(state, cfg)
        return out.get("final_answer", "").strip()

def build_graph(provider: str | None = None):
    """Compatibility wrapper: build the system and return its compiled graph.

    The *provider* argument is accepted for API compatibility but ignored;
    model selection happens inside the graph's router node.
    """
    system = HybridLangGraphMultiLLMSystem()
    return system.graph

if __name__ == "__main__":
    qa_system = HybridLangGraphMultiLLMSystem()
    # Exercise each routing branch once.
    for demo_query in (
        "llama-8: What is the capital of France?",
        "llama-70: Tell me about quantum mechanics.",
        "deepseek: What is the Riemann Hypothesis?",
    ):
        print(qa_system.process_query(demo_query))