|
|
|
from typing import TypedDict, Dict |
|
from langgraph.graph import StateGraph, END |
|
from langchain_core.prompts import ChatPromptTemplate |
|
from langchain_core.runnables.graph import MermaidDrawMethod |
|
|
|
import gradio as gr |
|
import os |
|
from langchain_groq import ChatGroq |
|
|
|
|
|
# Shared state passed between the LangGraph nodes.
#   query     -- the customer's message, supplied by the caller of the graph
#   category  -- label written by the categorize node
#   sentiment -- label written by the analyze_sentiment node
#   response  -- reply text written by a handler node or by escalate
State = TypedDict(
    "State",
    {
        "query": str,
        "category": str,
        "sentiment": str,
        "response": str,
    },
)
|
|
|
|
|
def get_llm(api_key=None, model_name="llama-3.3-70b-versatile"):
    """Build the Groq chat model used by every node in the workflow.

    Args:
        api_key: Groq API key. ``None`` or a blank/whitespace-only string
            (what an empty Gradio textbox yields) falls back to the
            ``GROQ_API_KEY`` environment variable. Surrounding whitespace
            on a string key is removed.
        model_name: Groq model identifier. Defaults to the model this app
            was written against.

    Returns:
        A ``ChatGroq`` instance configured with ``temperature=0``.
    """
    # The UI passes "" (not None) for an empty key field, and pasted keys
    # often carry a trailing newline/space, so normalise before checking.
    if isinstance(api_key, str):
        api_key = api_key.strip()
    if not api_key:
        api_key = os.getenv('GROQ_API_KEY')

    llm = ChatGroq(
        temperature=0,
        groq_api_key=api_key,
        model_name=model_name,
    )
    return llm
|
|
|
|
|
def categorize(state: State, llm) -> State:
    """Ask the LLM to classify ``state["query"]`` as Technical, Billing or General.

    The label is written to ``state["category"]`` and the same (mutated)
    state object is returned. If one of the known labels appears as a word
    in the model's reply, the first such label is stored in canonical form
    ("Technical", "Billing" or "General"). Otherwise the stripped reply is
    stored unchanged.
    """
    prompt = ChatPromptTemplate.from_template(
        "Categorize the following customer query into one of these categories: "
        "Technical, Billing, General. Respond with only the category name. "
        "Query: {query}"
    )
    chain = prompt | llm
    reply = chain.invoke({"query": state["query"]}).content.strip()

    # Models often wrap the label in prose, quotes or Markdown
    # ("Category: Technical.", "**Billing**"); downstream routing compares
    # against the bare category names, so reduce the reply to one of them.
    labels = {"technical": "Technical", "billing": "Billing", "general": "General"}
    words = "".join(ch if ch.isalpha() else " " for ch in reply.lower()).split()
    state["category"] = next((labels[w] for w in words if w in labels), reply)
    return state
|
|
|
def analyze_sentiment(state: State, llm) -> State:
    """Ask the LLM for the sentiment of ``state["query"]``.

    The label is written to ``state["sentiment"]`` and the same (mutated)
    state object is returned. If "positive", "neutral" or "negative" appears
    as a word in the reply, the first one found is stored in canonical form
    ("Positive", "Neutral" or "Negative"). Otherwise the stripped reply is
    stored unchanged.
    """
    prompt = ChatPromptTemplate.from_template(
        "Analyze the sentiment of the following customer query. "
        "Respond with either 'Positive', 'Neutral', or 'Negative'. Query: {query}"
    )
    chain = prompt | llm
    reply = chain.invoke({"query": state["query"]}).content.strip()

    # The quoted labels in the prompt invite replies such as "'Negative'" or
    # "Negative."; storing those verbatim would defeat the exact-label
    # comparison used to decide escalation.
    labels = {"positive": "Positive", "neutral": "Neutral", "negative": "Negative"}
    words = "".join(ch if ch.isalpha() else " " for ch in reply.lower()).split()
    state["sentiment"] = next((labels[w] for w in words if w in labels), reply)
    return state
|
|
|
def handle_technical(state: State, llm) -> State:
    """Generate a technical-support answer for ``state["query"]``.

    Stores the stripped model reply in ``state["response"]`` and returns the
    same state object.
    """
    template = "Provide a technical support response to the following query: {query}"
    answer = (ChatPromptTemplate.from_template(template) | llm).invoke(
        {"query": state["query"]}
    )
    state["response"] = answer.content.strip()
    return state
|
|
|
def handle_billing(state: State, llm) -> State:
    """Answer ``state["query"]`` as a billing-support request.

    The stripped model output is written to ``state["response"]``; the same
    state object is returned.
    """
    billing_prompt = ChatPromptTemplate.from_template(
        "Provide a billing-related support response to the following query: {query}"
    )
    inputs = {"query": state["query"]}
    message = (billing_prompt | llm).invoke(inputs)
    state["response"] = message.content.strip()
    return state
|
|
|
def handle_general(state: State, llm) -> State:
    """Produce a general-purpose support reply for ``state["query"]``.

    Sets ``state["response"]`` to the stripped model output and returns the
    mutated state.
    """
    runnable = ChatPromptTemplate.from_template(
        "Provide a general support response to the following query: {query}"
    ) | llm
    reply_text = runnable.invoke({"query": state["query"]}).content
    state["response"] = reply_text.strip()
    return state
|
|
|
def escalate(state: State) -> State:
    """Set a fixed hand-off notice as the reply instead of calling the LLM.

    The state is mutated in place (only ``"response"`` changes) and returned.
    """
    notice = "This query has been escalated to a human agent due to its negative sentiment."
    state.update(response=notice)
    return state
|
|
|
def route_query(state: State) -> str:
    """Choose the node that runs after sentiment analysis.

    Negative sentiment takes priority and routes to ``"escalate"``. Otherwise
    the category selects ``"handle_technical"`` or ``"handle_billing"``.
    Anything else, including an unrecognised or missing category, goes to
    ``"handle_general"``.

    Labels come from LLM output, so they are matched case-insensitively as
    whole words, using the first known label in the text. Replies such as
    ``"Negative."``, ``"'Billing'"`` or ``"Category: Technical"`` therefore
    route the same as the bare label, while ``"Neutral, not negative"``
    counts as neutral.
    """

    def first_label(key, labels):
        # Turn every non-letter into a separator, then return the first word
        # that is one of ``labels`` (or None when none is present).
        text = state.get(key) or ""
        words = "".join(ch if ch.isalpha() else " " for ch in text.lower()).split()
        return next((w for w in words if w in labels), None)

    if first_label("sentiment", ("positive", "neutral", "negative")) == "negative":
        return "escalate"

    category = first_label("category", ("technical", "billing", "general"))
    if category == "technical":
        return "handle_technical"
    if category == "billing":
        return "handle_billing"
    return "handle_general"
|
|
|
|
|
def get_workflow(llm):
    """Build and compile the customer-support graph bound to ``llm``.

    Flow: categorize -> analyze_sentiment -> one of handle_technical /
    handle_billing / handle_general / escalate (chosen by ``route_query``)
    -> END.
    """

    def bind(node_fn):
        # Adapt a ``(state, llm)`` node function to the single-argument
        # ``lambda state`` form registered with the graph.
        return lambda state: node_fn(state, llm)

    graph = StateGraph(State)

    graph.add_node("categorize", bind(categorize))
    graph.add_node("analyze_sentiment", bind(analyze_sentiment))
    graph.add_node("handle_technical", bind(handle_technical))
    graph.add_node("handle_billing", bind(handle_billing))
    graph.add_node("handle_general", bind(handle_general))
    graph.add_node("escalate", escalate)

    graph.add_edge("categorize", "analyze_sentiment")

    # Every value route_query can return is also the name of its target node.
    terminal_nodes = ("handle_technical", "handle_billing", "handle_general", "escalate")
    graph.add_conditional_edges(
        "analyze_sentiment",
        route_query,
        {name: name for name in terminal_nodes},
    )
    for name in terminal_nodes:
        graph.add_edge(name, END)

    graph.set_entry_point("categorize")
    return graph.compile()
|
|
|
|
|
def run_customer_support(query: str, api_key: str) -> Dict[str, str]:
    """Run one customer query through the support workflow.

    A fresh LLM client and compiled graph are built per call because the
    API key may differ between calls.

    Args:
        query: The customer's message.
        api_key: Groq API key as typed in the UI. Surrounding whitespace is
            removed, and a blank value is passed on as ``None``.

    Returns:
        ``{"Response": <reply text>}``. The text is ``""`` if the graph
        produced no response.
    """
    # Passing None (rather than the "" Gradio produces for an empty textbox)
    # lets get_llm fall back to the GROQ_API_KEY environment variable.
    key = (api_key or "").strip() or None
    llm = get_llm(key)
    app = get_workflow(llm)
    result = app.invoke({"query": query})
    return {"Response": result.get("response", "").strip()}
|
|
|
|
|
# Web UI: a query box and an API-key box feed run_customer_support, whose
# dict result is rendered as JSON.
gr_interface = gr.Interface(
    fn=run_customer_support,
    inputs=[
        gr.Textbox(
            lines=2,
            label="Customer Query",
            placeholder="Enter your customer support query here...",
        ),
        # type="password" masks the secret instead of echoing it on screen.
        gr.Textbox(
            label="GROQ API Key",
            placeholder="Enter your GROQ API key",
            type="password",
        ),
    ],
    outputs=gr.JSON(label="Response"),
    title="Customer Support Chatbot",
    description="Enter your query to receive assistance.",
)

# Start the server only when run as a script, so importing this module does
# not launch the app as a side effect.
if __name__ == "__main__":
    gr_interface.launch()