saritha commited on
Commit
c5ab28f
·
verified ·
1 Parent(s): 7f0a2dd

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +142 -0
app.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Import modules
2
+ from typing import TypedDict, Dict
3
+ from langgraph.graph import StateGraph, END
4
+ from langchain_core.prompts import ChatPromptTemplate
5
+ from langchain_core.runnables.graph import MermaidDrawMethod
6
+ from IPython.display import Image, display
7
+ import gradio as gr
8
+ import os
9
+ from langchain_groq import ChatGroq
10
+
11
+ # Define the State data structure
12
+ class State(TypedDict):
13
+ query: str
14
+ category: str
15
+ sentiment: str
16
+ response: str
17
+
18
+ # Function to get the language model
19
+ def get_llm(api_key=None):
20
+ if api_key is None:
21
+ api_key = os.getenv('GROQ_API_KEY')
22
+ llm = ChatGroq(
23
+ temperature=0,
24
+ groq_api_key=api_key,
25
+ model_name="llama-3.3-70b-versatile"
26
+ )
27
+ return llm
28
+
29
+ # Define the processing functions
30
+ def categorize(state: State, llm) -> State:
31
+ prompt = ChatPromptTemplate.from_template(
32
+ "Categorize the following customer query into one of these categories: "
33
+ "Technical, Billing, General. Query: {query}"
34
+ )
35
+ chain = prompt | llm
36
+ category = chain.invoke({"query": state["query"]}).content.strip()
37
+ state["category"] = category
38
+ return state
39
+
40
+ def analyze_sentiment(state: State, llm) -> State:
41
+ prompt = ChatPromptTemplate.from_template(
42
+ "Analyze the sentiment of the following customer query. "
43
+ "Respond with either 'Positive', 'Neutral', or 'Negative'. Query: {query}"
44
+ )
45
+ chain = prompt | llm
46
+ sentiment = chain.invoke({"query": state["query"]}).content.strip()
47
+ state["sentiment"] = sentiment
48
+ return state
49
+
50
+ def handle_technical(state: State, llm) -> State:
51
+ prompt = ChatPromptTemplate.from_template(
52
+ "Provide a technical support response to the following query: {query}"
53
+ )
54
+ chain = prompt | llm
55
+ response = chain.invoke({"query": state["query"]}).content.strip()
56
+ state["response"] = response
57
+ return state
58
+
59
+ def handle_billing(state: State, llm) -> State:
60
+ prompt = ChatPromptTemplate.from_template(
61
+ "Provide a billing-related support response to the following query: {query}"
62
+ )
63
+ chain = prompt | llm
64
+ response = chain.invoke({"query": state["query"]}).content.strip()
65
+ state["response"] = response
66
+ return state
67
+
68
+ def handle_general(state: State, llm) -> State:
69
+ prompt = ChatPromptTemplate.from_template(
70
+ "Provide a general support response to the following query: {query}"
71
+ )
72
+ chain = prompt | llm
73
+ response = chain.invoke({"query": state["query"]}).content.strip()
74
+ state["response"] = response
75
+ return state
76
+
77
+ def escalate(state: State) -> State:
78
+ state["response"] = "This query has been escalated to a human agent due to its negative sentiment."
79
+ return state
80
+
81
+ def route_query(state: State) -> str:
82
+ if state["sentiment"].lower() == "negative":
83
+ return "escalate"
84
+ elif state["category"].lower() == "technical":
85
+ return "handle_technical"
86
+ elif state["category"].lower() == "billing":
87
+ return "handle_billing"
88
+ else:
89
+ return "handle_general"
90
+
91
+ # Function to compile the workflow
92
+ def get_workflow(llm):
93
+ workflow = StateGraph(State)
94
+ workflow.add_node("categorize", lambda state: categorize(state, llm))
95
+ workflow.add_node("analyze_sentiment", lambda state: analyze_sentiment(state, llm))
96
+ workflow.add_node("handle_technical", lambda state: handle_technical(state, llm))
97
+ workflow.add_node("handle_billing", lambda state: handle_billing(state, llm))
98
+ workflow.add_node("handle_general", lambda state: handle_general(state, llm))
99
+ workflow.add_node("escalate", escalate)
100
+
101
+ workflow.add_edge("categorize", "analyze_sentiment")
102
+ workflow.add_conditional_edges("analyze_sentiment",
103
+ route_query, {
104
+ "handle_technical": "handle_technical",
105
+ "handle_billing": "handle_billing",
106
+ "handle_general": "handle_general",
107
+ "escalate": "escalate",
108
+ })
109
+ workflow.add_edge("handle_technical", END)
110
+ workflow.add_edge("handle_billing", END)
111
+ workflow.add_edge("handle_general", END)
112
+ workflow.add_edge("escalate", END)
113
+
114
+ workflow.set_entry_point("categorize")
115
+ return workflow.compile()
116
+
117
+ # Gradio interface function
118
+ def run_customer_support(query: str, api_key: str) -> Dict[str, str]:
119
+ llm = get_llm(api_key)
120
+ app = get_workflow(llm)
121
+ result = app.invoke({"query": query})
122
+ return {
123
+ "Query": query,
124
+ "Category": result.get("category", "").strip(),
125
+ "Sentiment": result.get("sentiment", "").strip(),
126
+ "Response": result.get("response", "").strip()
127
+ }
128
+
129
+ # Create the Gradio interface
130
+ gr_interface = gr.Interface(
131
+ fn=run_customer_support,
132
+ inputs=[
133
+ gr.inputs.Textbox(lines=2, label="Customer Query", placeholder="Enter your customer support query here..."),
134
+ gr.inputs.Password(label="GROQ API Key", placeholder="Enter your GROQ API key"),
135
+ ],
136
+ outputs=gr.outputs.JSON(label="Response"),
137
+ title="Customer Support Chatbot",
138
+ description="Enter your query to receive assistance.",
139
+ )
140
+
141
+ # Launch the Gradio interface
142
+ gr_interface.launch()