# AI-Notebook-Tutor — notebook_tutor/chainlit_frontend.py
# Chainlit front end for the Notebook-Tutor (includes the flashcard creation tool).
import os
import logging
import chainlit as cl
from dotenv import load_dotenv
from document_processing import DocumentManager
from retrieval import RetrievalManager
from langchain_core.messages import AIMessage, HumanMessage
from graph import create_tutor_chain, TutorState
# Load environment variables (e.g. API keys) from a local .env file, if present.
load_dotenv()
# Configure root logging once at import time; module logger follows the file's name.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@cl.on_chat_start
async def start_chat():
    """Greet the user, collect a notebook upload, and wire up the tutor chain.

    Stores model settings, the loaded documents, the retrieval manager, the
    RAG QA chain, and the LangGraph tutor chain in the Chainlit user session.
    """
    model_settings = {
        "model": "gpt-3.5-turbo",
        "temperature": 0,
        "top_p": 1,
        "frequency_penalty": 0,
        "presence_penalty": 0,
    }
    cl.user_session.set("settings", model_settings)

    await cl.Message(
        content="Welcome to the Notebook-Tutor! Please upload a Jupyter notebook (.ipynb and max. 5mb) to start."
    ).send()

    # Keep prompting until the user actually uploads a file.
    uploads = None
    while uploads is None:
        uploads = await cl.AskFileMessage(
            content="Please upload a Jupyter notebook (.ipynb, max. 5mb):",
            accept={"application/x-ipynb+json": [".ipynb"]},
            max_size_mb=5
        ).send()

    notebook = uploads[0]  # Only the first uploaded file is used.
    if notebook:
        manager = DocumentManager(notebook.path)
        manager.load_document()
        manager.initialize_retriever()
        cl.user_session.set("docs", manager.get_documents())
        cl.user_session.set("retrieval_manager", RetrievalManager(manager.get_retriever()))

        # Build the RAG QA chain, then the LangGraph tutor chain on top of it.
        retrieval_chain = cl.user_session.get("retrieval_manager").get_RAG_QA_chain()
        cl.user_session.set("retrieval_chain", retrieval_chain)
        cl.user_session.set("tutor_chain", create_tutor_chain(retrieval_chain))

        logger.info("Chat started and notebook uploaded successfully.")
@cl.on_message
async def main(message: cl.Message):
    """Route an incoming user message through the LangGraph tutor chain.

    Streams intermediate agent states back to the UI as they are produced
    and, on the final ("__end__") state, delivers any quiz, answer, or
    flashcard artifacts the graph created.

    Args:
        message: The Chainlit message object carrying the user's text.
    """
    # Retrieve the LangGraph chain prepared in start_chat; bail out politely
    # if no notebook has been uploaded yet.
    tutor_chain = cl.user_session.get("tutor_chain")
    if not tutor_chain:
        await cl.Message(content="No document processing setup found. Please upload a Jupyter notebook first.").send()
        return

    # Seed the graph state with the user's message; the supervisor node routes it.
    state = TutorState(
        messages=[HumanMessage(content=message.content)],
        next="supervisor",
        quiz=[],
        quiz_created=False,
        question_answered=False,
        flashcards_created=False,
        flashcard_filename="",
    )
    # Use the module logger (lazy %-formatting) instead of print so debug
    # output honors the logging configuration set at module import.
    logger.info("Initial state: %s", state)

    # Each streamed item maps a node name to that node's resulting state.
    for step in tutor_chain.stream(state, {"recursion_limit": 10}):
        logger.info("State after processing: %s", step)

        if "__end__" not in step:
            # Intermediate node: relay its latest message to the user, if any.
            agent_state = next(iter(step.values()))
            if "messages" in agent_state:
                response = agent_state["messages"][-1].content
                logger.info("Response: %s", response)
                await cl.Message(content=response).send()
            else:
                logger.error("No messages found in agent state.")
        else:
            # Final state: check which artifacts were produced and send them.
            final_state = next(iter(step.values()))
            # NOTE(review): the flags below are not mutually exclusive in this
            # code; if several are set, the same last message is sent once per
            # flag — preserved from the original behavior.
            last_message = final_state["messages"][-1].content if final_state.get("messages") else ""

            if final_state.get("quiz_created"):
                await cl.Message(content=last_message).send()

            if final_state.get("question_answered"):
                await cl.Message(content=last_message).send()

            if final_state.get("flashcards_created"):
                await cl.Message(content=last_message).send()

                # Resolve the flashcard file to an absolute path and attach it.
                flashcard_filename = final_state["flashcard_filename"]
                flashcard_path = os.path.abspath(flashcard_filename)
                logger.info("Sending flashcards file: %s", flashcard_path)
                file_element = cl.File(name=os.path.basename(flashcard_filename), path=flashcard_path)
                await cl.Message(
                    content="Here are your flashcards:",
                    elements=[file_element]
                ).send()

            logger.info("Reached END state.")
            break