DrishtiSharma commited on
Commit
7054888
·
verified ·
1 Parent(s): 9e8133a

Create interim.py

Browse files
Files changed (1) hide show
  1. interim.py +87 -0
interim.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import operator
4
+ import streamlit as st
5
+ from typing import TypedDict, Annotated, Sequence
6
+ from langchain_openai import ChatOpenAI
7
+ from langchain_core.tools import tool
8
+ from langchain_core.utils.function_calling import convert_to_openai_tool
9
+ from langgraph.graph import StateGraph, END
10
+
11
+ # Environment Setup
12
+ os.environ['OPENAI_API_KEY'] = os.getenv("OPENAI_API_KEY")
13
+
14
+ # Model Initialization
15
+ model = ChatOpenAI(temperature=0)
16
+
17
+ @tool
18
+ def multiply(first_number: int, second_number: int):
19
+ """Multiplies two numbers together and returns the result."""
20
+ return first_number * second_number
21
+
22
+ model_with_tools = model.bind(tools=[convert_to_openai_tool(multiply)])
23
+
24
+ # State Setup
25
+ class AgentState(TypedDict):
26
+ messages: Annotated[Sequence, operator.add]
27
+
28
+ graph = StateGraph(AgentState)
29
+
30
+ def invoke_model(state):
31
+ question = state['messages'][-1]
32
+ return {"messages": [model_with_tools.invoke(question)]}
33
+
34
+ graph.add_node("agent", invoke_model)
35
+
36
+ def invoke_tool(state):
37
+ tool_calls = state['messages'][-1].additional_kwargs.get("tool_calls", [])
38
+ for tool_call in tool_calls:
39
+ if tool_call.get("function").get("name") == "multiply":
40
+ res = multiply.invoke(json.loads(tool_call.get("function").get("arguments")))
41
+ return {"messages": [f"Tool Result: {res}"]}
42
+ return {"messages": ["No tool input provided."]}
43
+
44
+ graph.add_node("tool", invoke_tool)
45
+ graph.add_edge("tool", END)
46
+ graph.set_entry_point("agent")
47
+
48
+ def router(state):
49
+ calls = state['messages'][-1].additional_kwargs.get("tool_calls", [])
50
+ return "multiply" if calls else "end"
51
+
52
+ graph.add_conditional_edges("agent", router, {"multiply": "tool", "end": END})
53
+ app_graph = graph.compile()
54
+
55
+ # Streamlit Interface
56
+ st.title("Multiplication Tool")
57
+
58
+ tab1, tab2 = st.tabs(["Tool Showcase", "Ask General Queries"])
59
+
60
+ with tab1:
61
+ st.subheader("Try Multiplication")
62
+ col1, col2 = st.columns(2)
63
+
64
+ with col1:
65
+ first_number = st.number_input("First Number", value=0, step=1)
66
+ with col2:
67
+ second_number = st.number_input("Second Number", value=0, step=1)
68
+
69
+ if st.button("Multiply"):
70
+ question = f"What is {first_number} * {second_number}?"
71
+ output = app_graph.invoke({"messages": [question]})
72
+ st.success(output['messages'][-1])
73
+
74
+ with tab2:
75
+ st.subheader("General Query")
76
+ user_input = st.text_input("Enter your question here")
77
+
78
+ if st.button("Submit"):
79
+ if user_input:
80
+ try:
81
+ result = app_graph.invoke({"messages": [user_input]})
82
+ st.write("Response:")
83
+ st.success(result['messages'][-1])
84
+ except Exception as e:
85
+ st.error("Something went wrong. Try again!")
86
+ else:
87
+ st.warning("Please enter a valid input.")