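# NOTE: the text "Spaces: / Sleeping / Sleeping" appeared here; it looks like
# web-page status text captured together with the file, not program code.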
import os | |
import logging | |
import chainlit as cl | |
from dotenv import load_dotenv | |
from document_processing import DocumentManager | |
from retrieval import RetrievalManager | |
from langchain_core.messages import AIMessage, HumanMessage | |
from graph import create_tutor_chain, TutorState | |
# Load environment variables before anything else in the module runs.
load_dotenv()

# Module-level logger; the logging configuration emits INFO and above.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
async def start_chat():
    """Greet the user, collect a notebook upload and prepare the tutor chain.

    Stores the model settings in the Chainlit user session, sends a welcome
    message and re-issues the upload prompt until it returns something other
    than ``None``.  The first uploaded file is passed to ``DocumentManager``;
    its documents, a ``RetrievalManager`` built from its retriever, the RAG QA
    chain and the chain returned by ``create_tutor_chain`` are stored in the
    session under ``"docs"``, ``"retrieval_manager"``, ``"retrieval_chain"``
    and ``"tutor_chain"``.

    NOTE(review): no Chainlit registration decorator is visible on this
    coroutine in this view -- confirm it is registered as the chat-start
    handler.
    """
    cl.user_session.set(
        "settings",
        {
            "model": "gpt-3.5-turbo",
            "temperature": 0,
            "top_p": 1,
            "frequency_penalty": 0,
            "presence_penalty": 0,
        },
    )
    await cl.Message(
        content="Welcome to the Notebook-Tutor! Please upload a Jupyter notebook (.ipynb and max. 5mb) to start."
    ).send()

    # Keep asking for as long as the prompt comes back as None.
    uploads = None
    while uploads is None:
        prompt = cl.AskFileMessage(
            content="Please upload a Jupyter notebook (.ipynb, max. 5mb):",
            accept={"application/x-ipynb+json": [".ipynb"]},
            max_size_mb=5,
        )
        uploads = await prompt.send()

    notebook = uploads[0]  # only the first uploaded file is used
    if not notebook:
        return

    # Load the notebook and build its retriever.
    documents = DocumentManager(notebook.path)
    documents.load_document()
    documents.initialize_retriever()
    cl.user_session.set("docs", documents.get_documents())
    cl.user_session.set(
        "retrieval_manager",
        RetrievalManager(documents.get_retriever()),
    )

    # Build the LangGraph tutor chain on top of the retrieval chain.
    manager = cl.user_session.get("retrieval_manager")
    rag_chain = manager.get_RAG_QA_chain()
    cl.user_session.set("retrieval_chain", rag_chain)
    cl.user_session.set("tutor_chain", create_tutor_chain(rag_chain))

    logger.info("Chat started and notebook uploaded successfully.")
async def main(message: cl.Message):
    """Run a user message through the tutor chain and relay the results.

    Looks up ``"tutor_chain"`` in the Chainlit user session; if it is missing,
    tells the user to upload a notebook and returns.  Otherwise a fresh
    ``TutorState`` holding the user's text is streamed through the chain with
    ``{"recursion_limit": 10}`` as its config.

    For every streamed update without an ``"__end__"`` key, the content of the
    last message in the first value of the update is sent to the user.  For
    the ``"__end__"`` update, the last message is sent once for each of the
    ``quiz_created``, ``question_answered`` and ``flashcards_created`` flags
    that is set; for flashcards the file named by ``flashcard_filename`` is
    also attached.  Streaming stops after the ``"__end__"`` update.

    Diagnostics go through the module-level ``logger`` (as in ``start_chat``)
    rather than ``print`` so they honour the file's logging configuration.

    Args:
        message: The incoming Chainlit message; only ``message.content`` is
            read.

    NOTE(review): no Chainlit registration decorator is visible on this
    coroutine in this view -- confirm it is registered as the message handler.
    """
    # Retrieve the LangGraph chain from the session.
    tutor_chain = cl.user_session.get("tutor_chain")
    if not tutor_chain:
        await cl.Message(content="No document processing setup found. Please upload a Jupyter notebook first.").send()
        return

    # Create the initial state with the user message.
    state = TutorState(
        messages=[HumanMessage(content=message.content)],
        next="supervisor",
        quiz=[],
        quiz_created=False,
        question_answered=False,
        flashcards_created=False,
        flashcard_filename="",
    )
    logger.info("Initial state: %s", state)

    # NOTE(review): stream() is a synchronous iterator driven from inside a
    # coroutine; presumably it blocks the event loop between updates --
    # confirm whether an async variant of the chain should be used instead.
    for s in tutor_chain.stream(state, {"recursion_limit": 10}):
        logger.info("State after processing: %s", s)

        if "__end__" not in s:
            # Intermediate update: forward the newest message, if any.
            agent_state = next(iter(s.values()))
            if "messages" in agent_state:
                response = agent_state["messages"][-1].content
                logger.info("Response: %s", response)
                await cl.Message(content=response).send()
            else:
                logger.error("Error: No messages found in agent state.")
            continue

        # Final update: extract the final state.
        final_state = next(iter(s.values()))

        # Check if the quiz was created and send it to the frontend.
        if final_state.get("quiz_created"):
            quiz_message = final_state["messages"][-1].content
            await cl.Message(content=quiz_message).send()

        # Check if a question was answered and send the response.
        if final_state.get("question_answered"):
            qa_message = final_state["messages"][-1].content
            await cl.Message(content=qa_message).send()

        # Check if flashcards are ready and send the file to the frontend.
        if final_state.get("flashcards_created"):
            flashcards_message = final_state["messages"][-1].content
            await cl.Message(content=flashcards_message).send()

            # Resolve the stored filename to an absolute path for the upload.
            flashcard_filename = final_state["flashcard_filename"]
            logger.info("Flashcard filename: %s", flashcard_filename)
            flashcard_path = os.path.abspath(flashcard_filename)
            logger.info("Flashcard path: %s", flashcard_path)

            # Use the File class to send the file.
            file_element = cl.File(name=os.path.basename(flashcard_filename), path=flashcard_path)
            logger.info("Sending flashcards file: %s", file_element)
            await cl.Message(
                content="Here are your flashcards:",
                elements=[file_element],
            ).send()

        logger.info("Reached END state.")
        break