# BaristaBot / app.py
# (Hugging Face Space page artifacts converted to comments so the file parses:
#  "APRG's picture / Update app.py / 2dc3ffd verified")
#pip install langchain_google_genai langgraph gradio
import os
import sys
import typing
from typing import Annotated, Literal, Iterable
from typing_extensions import TypedDict
from langchain_google_genai import ChatGoogleGenerativeAI
from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode
from langchain_core.tools import tool
from langchain_core.messages import AIMessage, ToolMessage, HumanMessage, BaseMessage, SystemMessage
from random import randint
from tkinter import messagebox
#messagebox.showinfo("Test", "Script run successfully")
import gradio as gr
import logging
class OrderState(TypedDict):
    """State representing the customer's order conversation."""
    # Conversation history; the add_messages reducer appends new messages
    # to the existing list instead of replacing it.
    messages: Annotated[list, add_messages]
    # Human-readable items added so far, e.g. "latte (oat milk)".
    order: list[str]
    # Set True when the order is placed or the user quits; routes to END.
    finished: bool
# System instruction for the BaristaBot.
# NOTE(review): this is a (role, text) tuple in LangChain's shorthand message
# format, but chatbot_with_tools wraps it in SystemMessage(content=...), whose
# content should be the text element only — verify how it is consumed.
BARISTABOT_SYSINT = (
    "system",
    "You are a BaristaBot, an interactive cafe ordering system. A human will talk to you about the "
    "available products. Answer questions about menu items, help customers place orders, and "
    "confirm details before finalizing. Use the provided tools to manage the order."
)
# Canned greeting emitted on the first turn, before the LLM is ever invoked.
WELCOME_MSG = "Welcome to the BaristaBot cafe. Type `q` to quit. How may I serve you today?"
# Initialize the Google Gemini LLM.
# NOTE(review): presumably reads API credentials (e.g. GOOGLE_API_KEY) from the
# environment — none are passed here; confirm deployment config.
llm = ChatGoogleGenerativeAI(model="gemini-1.5-flash-latest")
@tool
def get_menu() -> str:
    """Provide the cafe menu.

    Reads the menu text from ``menu.txt`` in the working directory.

    Returns:
        The full menu text, or an apology string if the file cannot be
        read — returning a message (rather than raising) keeps the tool
        node from crashing the graph mid-conversation.
    """
    try:
        with open("menu.txt", "r", encoding="UTF-8") as f:
            return f.read()
    except OSError:
        # Missing/unreadable menu file: surface a readable message the LLM
        # can relay to the customer.
        return "Sorry, the menu is currently unavailable."
@tool
def add_to_order(drink: str, modifiers: Iterable[str] = ()) -> str:
    """Adds the specified drink to the customer's order.

    Args:
        drink: Name of the drink to add.
        modifiers: Optional customizations (e.g. "oat milk"). Defaults to an
            empty tuple — the original used a mutable ``[]`` default, which
            is the classic shared-mutable-default pitfall.

    Returns:
        A human-readable description of the added item.
    """
    # Materialize once: if callers pass a generator, the original's
    # truthiness check would not detect emptiness and join would consume it.
    mods = list(modifiers)
    return f"{drink} ({', '.join(mods) if mods else 'no modifiers'})"
@tool
def confirm_order() -> str:
    """Signal that the customer should be asked to confirm the order."""
    # The actual confirmation text is assembled in order_node; this marker
    # return is what the tool call reports.
    marker = "Order confirmation requested"
    return marker
@tool
def get_order() -> str:
    """Report that the current order contents were requested."""
    # order_node substitutes the real order text; this is the marker value.
    marker = "Current order details requested"
    return marker
@tool
def clear_order() -> str:
    """Signal that the current order should be emptied."""
    # The actual list mutation happens in order_node.
    marker = "Order cleared"
    return marker
@tool
def place_order() -> int:
    """Sends the order to the kitchen.

    Returns:
        A simulated estimated wait time in minutes (2-10).
    """
    eta_minutes = randint(2, 10)
    return eta_minutes
def chatbot_with_tools(state: OrderState) -> OrderState:
    """Chatbot node: invoke the tool-bound LLM over the conversation.

    On the very first turn (no messages in state) it short-circuits with the
    static welcome message instead of calling the LLM. Any LLM failure is
    caught at this boundary and converted into an apologetic AI message so
    the graph keeps running.

    Args:
        state: Current conversation state.

    Returns:
        Updated state with the new AI message appended (via add_messages).
    """
    logging.info(f"Messagelist sent to chatbot node: {[msg.content for msg in state.get('messages', [])]}")
    defaults = {"order": [], "finished": False}
    # BUGFIX: BARISTABOT_SYSINT is a ("system", text) tuple; SystemMessage
    # content must be the instruction text itself, not the whole tuple.
    system_msg = SystemMessage(content=BARISTABOT_SYSINT[1])
    # Ensure we always have at least a system message
    if not state.get("messages", []):
        new_output = AIMessage(content=WELCOME_MSG)
        return defaults | state | {"messages": [system_msg, new_output]}
    try:
        # Prepend the system instruction for this invocation only; it is not
        # persisted in state, so it is re-added on every turn.
        messages_with_system = [system_msg] + state.get("messages", [])
        new_output = llm_with_tools.invoke(messages_with_system)
        return defaults | state | {"messages": [new_output]}
    except Exception as e:
        # Fallback if LLM processing fails; log the traceback instead of
        # silently swallowing it.
        logging.exception("LLM invocation failed")
        return defaults | state | {"messages": [AIMessage(content=f"I'm having trouble processing that. {str(e)}")]}
def order_node(state: OrderState) -> OrderState:
    """Execute order-management tool calls from the last AI message.

    Interprets each tool call (add_to_order / confirm_order / get_order /
    clear_order / place_order) against the order list held in state and
    emits one ToolMessage per call.

    Args:
        state: Current conversation state; the last message must carry
            ``tool_calls``.

    Returns:
        Partial state update: tool reply messages, the (possibly mutated)
        order list, and ``finished`` set once the order has been placed.

    Raises:
        NotImplementedError: If an unrecognized tool name is encountered.
    """
    logging.info("order node")
    last_msg = state.get("messages", [])[-1]
    current_order = state.get("order", [])
    replies = []
    placed = False
    for call in last_msg.tool_calls:
        name = call["name"]
        args = call["args"]
        if name == "add_to_order":
            mods = args.get("modifiers", [])
            suffix = ", ".join(mods) if mods else "no modifiers"
            current_order.append(f'{args["drink"]} ({suffix})')
            reply = "\n".join(current_order)
        elif name == "confirm_order":
            reply = "Your current order:\n" + "\n".join(current_order) + "\nIs this correct?"
        elif name == "get_order":
            reply = "\n".join(current_order) if current_order else "(no order)"
        elif name == "clear_order":
            current_order.clear()
            reply = "Order cleared"
        elif name == "place_order":
            summary = "\n".join(current_order)
            placed = True
            reply = f"Order placed successfully!\nYour order:\n{summary}\nEstimated wait: {randint(2, 10)} minutes"
        else:
            raise NotImplementedError(f'Unknown tool call: {name}')
        replies.append(
            ToolMessage(
                content=reply,
                name=name,
                tool_call_id=call["id"],
            )
        )
    return {"messages": replies, "order": current_order, "finished": placed}
def maybe_route_to_tools(state: OrderState) -> str:
    """Conditional edge from the chatbot node.

    Routing order: END if the conversation is finished; "tools" if any tool
    call targets an auto-executed tool (get_menu); "ordering" for the
    order-management tools; otherwise "human" for a plain chat reply.

    Raises:
        ValueError: If state carries no messages at all.
    """
    msgs = state.get("messages", [])
    if not msgs:
        raise ValueError(f"No messages found when parsing state: {state}")
    last = msgs[-1]
    if state.get("finished", False):
        logging.info("from chatbot GOTO End node")
        return END
    calls = getattr(last, "tool_calls", None)
    if calls:
        auto_names = tool_node.tools_by_name.keys()
        if any(c["name"] in auto_names for c in calls):
            logging.info("from chatbot GOTO tools node")
            return "tools"
        logging.info("from chatbot GOTO order node")
        return "ordering"
    logging.info("from chatbot GOTO human node")
    return "human"
def human_node(state: OrderState) -> OrderState:
    """Inspect the latest user message and flag quit requests.

    If the last message is one of the quit words, mark the conversation
    finished; otherwise pass the state through unchanged.
    """
    logging.info(f"Messagelist sent to human node: {[msg.content for msg in state.get('messages', [])]}")
    quit_words = {"q", "quit", "exit", "goodbye"}
    latest = state["messages"][-1]
    if latest.content.lower() in quit_words:
        state["finished"] = True
    return state
def maybe_exit_human_node(state: OrderState) -> Literal["chatbot", "__end__"]:
    """Conditional edge from the human node.

    Ends the run when the conversation is finished or when the last message
    is already an AI reply (one Gradio turn is complete); otherwise hands
    control back to the chatbot.
    """
    if state.get("finished", False):
        logging.info("from human GOTO End node")
        return END
    if isinstance(state["messages"][-1], AIMessage):
        logging.info("Chatbot response obtained, ending conversation")
        return END
    logging.info("from human GOTO chatbot node")
    return "chatbot"
# Prepare tools: get_menu is auto-executed by the prebuilt ToolNode ...
auto_tools = [get_menu]
tool_node = ToolNode(auto_tools)
# ... while order tools are interpreted manually by order_node.
order_tools = [add_to_order, confirm_order, get_order, clear_order, place_order]
# Bind all tools to the LLM so it can emit tool calls for any of them.
llm_with_tools = llm.bind_tools(auto_tools + order_tools)
# Build the graph
graph_builder = StateGraph(OrderState)
# Add nodes
graph_builder.add_node("chatbot", chatbot_with_tools)
graph_builder.add_node("human", human_node)
graph_builder.add_node("tools", tool_node)
graph_builder.add_node("ordering", order_node)
# Conditional routing: chatbot -> tools/ordering/human/END, human -> chatbot/END.
graph_builder.add_conditional_edges("chatbot", maybe_route_to_tools)
graph_builder.add_conditional_edges("human", maybe_exit_human_node)
# Tool results always flow back to the chatbot.
graph_builder.add_edge("tools", "chatbot")
graph_builder.add_edge("ordering", "chatbot")
# Entry point: each Gradio turn starts at the human node with the new message.
graph_builder.add_edge(START, "human")
# Compile the graph into a runnable.
chat_graph = graph_builder.compile()
def convert_history_to_messages(history: list) -> list[BaseMessage]:
    """Convert Gradio (user, bot) tuple history into LangChain messages.

    NOTE(review): expects the legacy tuple-pair history format; gradio_chat
    below handles the "messages" dict format itself, so this helper appears
    unused — confirm before removing.

    Args:
        history: Gradio chat history as (user_text, bot_text) pairs.

    Returns:
        LangChain HumanMessage/AIMessage objects in conversation order,
        skipping empty entries.
    """
    converted: list[BaseMessage] = []
    for user_text, bot_text in history:
        if user_text:
            converted.append(HumanMessage(content=user_text))
        if bot_text:
            converted.append(AIMessage(content=bot_text))
    return converted
def gradio_chat(message: str, history: list) -> str:
    """Gradio-compatible chat function that manages the conversation state.

    Rebuilds the LangChain message list from Gradio's "messages"-format
    history on every turn, appends the new user message, and runs one pass
    of the graph.

    NOTE(review): the order list is reset to [] on every turn, so only the
    message history survives between turns — confirm this is intended.

    Args:
        message: User's input message.
        history: Gradio chat history as {"role", "content"} dicts.

    Returns:
        The bot's response content, or an error string on failure.
    """
    logging.info(f"{len(history)} history so far: {history}")
    # Substitute a greeting for blank input so the graph always has content.
    if not message or message.strip() == "":
        message = "Hello, how can I help you today?"
    # Rebuild the LangChain conversation from the Gradio history.
    role_map = {"user": HumanMessage, "assistant": AIMessage}
    conversation = []
    for entry in history:
        text = entry["content"]
        if not text.strip():
            continue
        msg_cls = role_map.get(entry["role"])
        if msg_cls is not None:
            conversation.append(msg_cls(content=text))
    conversation.append(HumanMessage(content=message))
    conversation_state = {
        "messages": conversation,
        "order": [],
        "finished": False,
    }
    logging.info(f"Conversation so far: {str(conversation_state)}")
    try:
        # One graph pass per Gradio turn; cap recursion to avoid loops.
        conversation_state = chat_graph.invoke(conversation_state, {"recursion_limit": 10})
        latest_message = conversation_state["messages"][-1]
        logging.info(f"return: {latest_message.content}")
        return latest_message.content
    except Exception as e:
        # Boundary catch: report the failure to the user instead of crashing the UI.
        return f"An error occurred: {str(e)}"
# Gradio interface
def launch_baristabot():
    """Build the Gradio chat UI around gradio_chat and start serving it."""
    interface = gr.ChatInterface(
        fn=gradio_chat,
        type="messages",
        title="BaristaBot",
        description="Your friendly AI cafe assistant",
        theme="ocean",
    )
    interface.launch()
if __name__ == "__main__":
    # Initiate logging before launching so node-level logging.info calls
    # are visible on stdout (useful in hosted container logs).
    logging.basicConfig(
        stream=sys.stdout,
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s')
    launch_baristabot()