Spaces:
Sleeping
Sleeping
#pip install langchain_google_genai langgraph gradio | |
import os | |
import sys | |
import typing | |
from typing import Annotated, Literal, Iterable | |
from typing_extensions import TypedDict | |
from langchain_google_genai import ChatGoogleGenerativeAI | |
from langgraph.graph import StateGraph, START, END | |
from langgraph.graph.message import add_messages | |
from langgraph.prebuilt import ToolNode | |
from langchain_core.tools import tool | |
from langchain_core.messages import AIMessage, ToolMessage, HumanMessage, BaseMessage, SystemMessage | |
from random import randint | |
from tkinter import messagebox | |
#messagebox.showinfo("Test", "Script run successfully") | |
import gradio as gr | |
import logging | |
class OrderState(TypedDict):
    """Graph state shared by the BaristaBot nodes.

    Holds the chat transcript, the drinks ordered so far, and a flag that
    tells the routers the conversation should stop.
    """

    # Chat transcript. The add_messages reducer merges the messages a node
    # returns into the existing list instead of replacing the whole list.
    messages: Annotated[list, add_messages]
    # One formatted line item per drink, e.g. "Latte (oat milk)".
    order: list[str]
    # Set when an order is placed or the user types a quit word; the routers
    # then send the graph to END.
    finished: bool
# System prompt for every LLM call, stored as a LangChain-style
# ("role", text) pair: the role is "system", the prompt text is element [1].
_BARISTABOT_PROMPT_SENTENCES = (
    "You are a BaristaBot, an interactive cafe ordering system.",
    "A human will talk to you about the available products.",
    "Answer questions about menu items, help customers place orders, and confirm details before finalizing.",
    "Use the provided tools to manage the order.",
)
BARISTABOT_SYSINT = ("system", " ".join(_BARISTABOT_PROMPT_SENTENCES))

# Greeting the chatbot node emits when a conversation has no messages yet.
WELCOME_MSG = (
    "Welcome to the BaristaBot cafe. "
    "Type `q` to quit. "
    "How may I serve you today?"
)
# Initialize the Google Gemini LLM.
# The model can be overridden with the BARISTABOT_MODEL environment variable
# without editing the code; an unset or empty variable keeps the default.
# NOTE(review): Gemini 1.5 model aliases are being retired -- confirm the
# default still resolves, or set BARISTABOT_MODEL.
llm = ChatGoogleGenerativeAI(
    model=os.environ.get("BARISTABOT_MODEL") or "gemini-1.5-flash-latest"
)
def get_menu() -> str:
    """Provide the cafe menu."""
    # The docstring above is the tool description the LLM sees, so it is kept
    # verbatim. The menu is re-read on every call from "menu.txt", resolved
    # against the current working directory.
    menu_path = "menu.txt"
    with open(menu_path, mode="r", encoding="UTF-8") as menu_file:
        menu_text = menu_file.read()
    return menu_text
def add_to_order(drink: str, modifiers: Iterable[str] = ()) -> str:
    """Adds the specified drink to the customer's order."""
    # The docstring above is the tool description the LLM sees, so it is kept
    # verbatim. This function only formats a line item such as
    # "Latte (oat milk, extra shot)"; it does not store anything itself.
    #
    # The default is an immutable tuple (it used to be a shared mutable list).
    # The argument is normalized before use:
    #   * None or "" means no modifiers;
    #   * a bare string is a single modifier, not a sequence of characters;
    #   * any other iterable, including a one-shot generator, is materialized
    #     once, so the emptiness check and the join see the same items.
    if modifiers is None:
        modifier_list = []
    elif isinstance(modifiers, str):
        modifier_list = [modifiers] if modifiers else []
    else:
        modifier_list = list(modifiers)
    modifier_text = ", ".join(modifier_list) if modifier_list else "no modifiers"
    return f"{drink} ({modifier_text})"
def confirm_order() -> str:
    """Asks the customer to confirm the order."""
    # Schema stub: the LLM learns this tool's name and docstring via bind_tools.
    # Calls to it are normally routed to order_node, which builds the actual
    # confirmation text from the running order.
    status_text = "Order confirmation requested"
    return status_text
def get_order() -> str:
    """Returns the current order."""
    # Schema stub: order_node normally answers this call with the real list
    # of line items; this body only returns a fixed status string.
    status_text = "Current order details requested"
    return status_text
def clear_order() -> str:
    """Clears the current order."""
    # Schema stub: the running order lives in the graph state, and order_node
    # empties it when this tool is called. Nothing is cleared here.
    status_text = "Order cleared"
    return status_text
def place_order() -> int:
    """Sends the order to the kitchen."""
    # No kitchen integration exists: the "estimate" is a random wait in
    # minutes, drawn uniformly from 2 to 10 inclusive.
    shortest_wait, longest_wait = 2, 10
    return randint(shortest_wait, longest_wait)
def chatbot_with_tools(state: OrderState) -> OrderState:
    """Run one chatbot step: let the tool-bound LLM answer the conversation so far.

    On an empty conversation, the LLM is not called. Instead the system prompt
    and the welcome greeting are seeded. Otherwise the history, prefixed with
    the system prompt, is sent to ``llm_with_tools``. The model's reply, which
    may contain tool calls for the router to follow, is returned.

    Args:
        state: current graph state. Missing "order"/"finished" keys default to
            an empty order and False.

    Returns:
        The state merged with the new message(s). If the LLM call raises, the
        reply is an apology message that includes the error text.
    """
    messages = state.get("messages", [])
    logging.info(f"Messagelist sent to chatbot node: {[msg.content for msg in messages]}")
    defaults = {"order": [], "finished": False}

    # BARISTABOT_SYSINT is declared as a ("system", text) pair. SystemMessage
    # must receive the prompt text itself, not the tuple.
    if isinstance(BARISTABOT_SYSINT, tuple):
        sysint_text = BARISTABOT_SYSINT[1]
    else:
        sysint_text = BARISTABOT_SYSINT
    system_message = SystemMessage(content=sysint_text)

    if not messages:
        # Fresh conversation: seed the system prompt and greet the customer.
        return defaults | state | {"messages": [system_message, AIMessage(content=WELCOME_MSG)]}

    # Prepend the system instruction only if the history does not already
    # start with one (the greeting branch above stores it in the state).
    if isinstance(messages[0], SystemMessage):
        messages_with_system = list(messages)
    else:
        messages_with_system = [system_message, *messages]

    try:
        new_output = llm_with_tools.invoke(messages_with_system)
    except Exception as e:
        # Best-effort fallback keeps the chat alive, but the failure is logged
        # instead of being silently swallowed.
        logging.exception("LLM call failed in chatbot node")
        return defaults | state | {"messages": [AIMessage(content=f"I'm having trouble processing that. {str(e)}")]}
    return defaults | state | {"messages": [new_output]}
def order_node(state: OrderState) -> OrderState:
    """Execute the order-management tool calls from the chatbot's last message.

    Handles add_to_order, confirm_order, get_order, clear_order and
    place_order against the running order kept in the graph state. Each call
    is answered with a ToolMessage so the chatbot can continue the dialogue.

    Returns:
        A partial state update containing:
        * the tool replies;
        * the updated order (a new list; the incoming one is not mutated);
        * ``finished``, set to True when a non-empty order was placed.

    Raises:
        NotImplementedError: if the message contains a tool call this node
            does not handle.
    """
    logging.info("order node")
    tool_msg = state.get("messages", [])[-1]
    # Work on a copy so the list held by the incoming state is left untouched.
    order = list(state.get("order", []))
    outbound_msgs = []
    order_placed = False

    for tool_call in tool_msg.tool_calls:
        tool_name = tool_call["name"]
        tool_args = tool_call["args"]

        if tool_name == "add_to_order":
            # Reuse the tool's own formatter so line items look the same everywhere.
            order.append(add_to_order(tool_args["drink"], tool_args.get("modifiers") or []))
            response = "\n".join(order)
        elif tool_name == "confirm_order":
            order_text = "\n".join(order) if order else "(no order)"
            response = "Your current order:\n" + order_text + "\nIs this correct?"
        elif tool_name == "get_order":
            response = "\n".join(order) if order else "(no order)"
        elif tool_name == "clear_order":
            order.clear()
            response = "Order cleared"
        elif tool_name == "place_order":
            if not order:
                # Do not send an empty order to the kitchen or end the chat;
                # let the chatbot tell the customer instead.
                response = "There is nothing to place yet: the order is empty."
            else:
                order_placed = True
                # place_order() supplies the estimated wait in minutes.
                response = (
                    "Order placed successfully!\nYour order:\n"
                    + "\n".join(order)
                    + f"\nEstimated wait: {place_order()} minutes"
                )
        else:
            raise NotImplementedError(f'Unknown tool call: {tool_name}')

        outbound_msgs.append(
            ToolMessage(
                content=response,
                name=tool_name,
                tool_call_id=tool_call["id"],
            )
        )

    return {"messages": outbound_msgs, "order": order, "finished": order_placed}
def maybe_route_to_tools(state: OrderState) -> str:
    """Pick the node that runs after the chatbot.

    Routing:
    * END once the conversation is flagged finished;
    * "tools" when the last message calls any tool registered on ``tool_node``;
    * "ordering" for the remaining tool calls;
    * "human" when the chatbot produced a plain reply.

    Raises:
        ValueError: if the state holds no messages at all.
    """
    msgs = state.get("messages", [])
    if not msgs:
        raise ValueError(f"No messages found when parsing state: {state}")
    last_message = msgs[-1]

    if state.get("finished", False):
        logging.info("from chatbot GOTO End node")
        return END

    if not hasattr(last_message, "tool_calls") or len(last_message.tool_calls) == 0:
        logging.info("from chatbot GOTO human node")
        return "human"

    # tool_node runs the tools it knows (the menu lookup); any other tool call
    # is handled by the order node.
    known_tools = tool_node.tools_by_name
    if any(call["name"] in known_tools for call in last_message.tool_calls):
        logging.info("from chatbot GOTO tools node")
        return "tools"

    logging.info("from chatbot GOTO order node")
    return "ordering"
def human_node(state: OrderState) -> OrderState:
    """Check the latest message for a quit command.

    The graph starts here with the user's message already at the end of the
    transcript. When that message is one of the quit words, matched
    case-insensitively and ignoring surrounding whitespace, the conversation
    is flagged as finished.

    Returns:
        The state, with ``finished`` set to True on a quit word. The incoming
        dict is not mutated.
    """
    logging.info(f"Messagelist sent to human node: {[msg.content for msg in state.get('messages', [])]}")
    last_msg = state["messages"][-1]
    text = last_msg.content
    # Content may be a list of parts (multimodal messages); only plain text
    # can be a quit command, and "q " should quit just like "q".
    if isinstance(text, str) and text.strip().lower() in {"q", "quit", "exit", "goodbye"}:
        return {**state, "finished": True}
    return state
def maybe_exit_human_node(state: OrderState) -> Literal["chatbot", "__end__"]:
    """Decide where the graph goes after the human node.

    Returns END in two cases:
    * the conversation is flagged finished;
    * the latest message is already the chatbot's reply, so a single graph
      run covers one user turn.

    Otherwise returns "chatbot", so the LLM answers the user's message.
    """
    if state.get("finished", False):
        logging.info("from human GOTO End node")
        return END

    latest = state["messages"][-1]
    if not isinstance(latest, AIMessage):
        logging.info("from human GOTO chatbot node")
        return "chatbot"

    # The chatbot already answered this turn; hand control back to the caller.
    logging.info("Chatbot response obtained, ending conversation")
    return END
# Tools executed directly by LangGraph's prebuilt ToolNode.
auto_tools = [get_menu]
tool_node = ToolNode(auto_tools)

# Order-management tools. maybe_route_to_tools sends calls that match no
# tool_node tool to order_node, which keeps the running order in the state.
order_tools = [
    add_to_order,
    confirm_order,
    get_order,
    clear_order,
    place_order,
]

# The LLM is told about every tool, whichever node ends up executing a call.
llm_with_tools = llm.bind_tools([*auto_tools, *order_tools])
def _build_order_graph() -> StateGraph:
    """Assemble (but do not compile) the BaristaBot StateGraph."""
    builder = StateGraph(OrderState)

    # One node per role: the LLM turn, the user turn, the prebuilt menu
    # ToolNode and the order-management handler.
    for node_name, node_action in (
        ("chatbot", chatbot_with_tools),
        ("human", human_node),
        ("tools", tool_node),
        ("ordering", order_node),
    ):
        builder.add_node(node_name, node_action)

    # Routers choose the successors of the chatbot and human nodes at runtime.
    builder.add_conditional_edges("chatbot", maybe_route_to_tools)
    builder.add_conditional_edges("human", maybe_exit_human_node)
    # Tool results always flow back to the chatbot so it can respond to them.
    builder.add_edge("tools", "chatbot")
    builder.add_edge("ordering", "chatbot")
    # Every run starts by inspecting the user's latest message.
    builder.add_edge(START, "human")
    return builder


graph_builder = _build_order_graph()
chat_graph = graph_builder.compile()
def convert_history_to_messages(history: list) -> list[BaseMessage]:
    """
    Convert Gradio chat history to a list of Langchain messages.

    Both history formats Gradio's ChatInterface can produce are accepted:
    - "tuples": [user_text, bot_text] pairs; an empty/None side is skipped.
    - "messages": dicts with "role" ("user"/"assistant") and "content".
      Entries whose content is blank or not text (e.g. file uploads) are
      skipped, as are other roles.

    Args:
    - history: Gradio's chat history in either format
    Returns:
    - List of Langchain BaseMessage objects, in conversation order
    """
    messages = []
    for entry in history:
        if isinstance(entry, dict):
            # "messages" format, the one this app's ChatInterface(type="messages") uses.
            content = entry.get("content")
            if not isinstance(content, str) or not content.strip():
                continue
            role = entry.get("role")
            if role == "user":
                messages.append(HumanMessage(content=content))
            elif role == "assistant":
                messages.append(AIMessage(content=content))
            continue
        # Legacy "tuples" format: one (human, ai) pair per exchange.
        human, ai = entry
        if human:
            messages.append(HumanMessage(content=human))
        if ai:
            messages.append(AIMessage(content=ai))
    return messages
# Running order per Gradio session, keyed by session hash (None when called
# outside Gradio). The graph is invoked once per chat turn with a freshly
# built state, so without this the drinks added in earlier turns were lost.
# An entry is dropped when an order is placed, the user quits, or the chat
# history is cleared.
# NOTE(review): sessions abandoned mid-order keep their (small) entry until
# restart; add eviction if that matters for a long-running Space.
_SESSION_ORDERS: dict[typing.Optional[str], list[str]] = {}


def gradio_chat(message: str, history: list, request: "gr.Request" = None) -> str:
    """
    Gradio-compatible chat function that manages the conversation state.
    Args:
    - message: User's input message
    - history: Gradio's chat history in "messages" format (dicts with
      "role" and "content")
    - request: injected by Gradio (matched by its gr.Request annotation); its
      session hash keys the running order. Optional, so the function can also
      be called directly.
    Returns:
    - Bot's response as a string
    """
    logging.info(f"{len(history)} history so far: {history}")
    # Ensure non-empty message
    if not message or message.strip() == "":
        message = "Hello, how can I help you today?"

    session_key = getattr(request, "session_hash", None)
    if not history:
        # New or cleared chat: start from an empty order.
        _SESSION_ORDERS.pop(session_key, None)

    # Convert history to Langchain messages. Blank and non-text entries
    # (e.g. file uploads) are skipped instead of crashing on .strip().
    conversation_messages = []
    for old_message in history:
        content = old_message["content"]
        if not isinstance(content, str) or not content.strip():
            continue
        if old_message["role"] == "user":
            conversation_messages.append(HumanMessage(content=content))
        elif old_message["role"] == "assistant":
            conversation_messages.append(AIMessage(content=content))
    # Add current message
    conversation_messages.append(HumanMessage(content=message))

    # Initial state: the history plus a copy of this session's running order
    conversation_state = {
        "messages": conversation_messages,
        "order": list(_SESSION_ORDERS.get(session_key, [])),
        "finished": False,
    }
    logging.info(f"Conversation so far: {str(conversation_state)}")
    try:
        conversation_state = chat_graph.invoke(conversation_state, {"recursion_limit": 10})
    except Exception as e:
        logging.exception("BaristaBot graph invocation failed")
        return f"An error occurred: {str(e)}"

    # Remember the order for the next turn, or forget it once the chat is done.
    if conversation_state.get("finished", False):
        _SESSION_ORDERS.pop(session_key, None)
    else:
        _SESSION_ORDERS[session_key] = list(conversation_state.get("order", []))

    latest_message = conversation_state["messages"][-1]
    if not isinstance(latest_message, AIMessage):
        # e.g. a quit command: the graph ends right after the human node, so
        # the last message is the user's own text. Don't echo it back.
        return "Thank you for visiting the BaristaBot cafe. Goodbye!"
    logging.info(f"return: {latest_message.content}")
    return latest_message.content
# Web UI: a Gradio chat window wired to gradio_chat.
def launch_baristabot():
    """Create the BaristaBot Gradio ChatInterface and launch it."""
    ui_options = dict(
        type="messages",
        title="BaristaBot",
        description="Your friendly AI cafe assistant",
        theme="ocean",
    )
    chat_ui = gr.ChatInterface(gradio_chat, **ui_options)
    chat_ui.launch()
if __name__ == "__main__":
    # Send INFO-level log records (including the node-routing trace) to
    # stdout, then start the UI.
    log_format = '%(asctime)s - %(levelname)s - %(message)s'
    logging.basicConfig(level=logging.INFO, format=log_format, stream=sys.stdout)
    launch_baristabot()