File size: 4,821 Bytes
48d9af7
 
c4eb0c2
 
 
 
deeba11
48d9af7
c4eb0c2
 
 
 
48d9af7
 
 
 
 
c4eb0c2
 
 
 
 
 
 
 
 
 
48d9af7
c4eb0c2
 
 
 
 
 
 
 
 
 
 
4e8b6ab
c4eb0c2
 
 
 
 
 
 
 
deeba11
 
4e8b6ab
48d9af7
 
 
 
deeba11
c4eb0c2
 
deeba11
48d9af7
deeba11
48d9af7
c4eb0c2
 
 
deeba11
 
48d9af7
 
 
 
 
 
 
 
 
deeba11
 
 
 
48d9af7
deeba11
c4eb0c2
deeba11
 
 
 
 
 
 
 
 
 
48d9af7
 
 
4e8b6ab
48d9af7
 
4e8b6ab
48d9af7
4e8b6ab
48d9af7
 
4e8b6ab
 
48d9af7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4e8b6ab
deeba11
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import os
import logging
import chainlit as cl
from dotenv import load_dotenv
from document_processing import DocumentManager
from retrieval import RetrievalManager
from langchain_core.messages import AIMessage, HumanMessage
from graph import create_tutor_chain, TutorState

# Read key/value pairs from a local .env file into the process environment so
# the rest of the app can pick them up via os.environ.
load_dotenv()

# Root logging at INFO level; this module logs through its own named logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

@cl.on_chat_start
async def start_chat():
    """Set up a chat session around an uploaded Jupyter notebook.

    Stores default model settings, greets the user and keeps prompting until a
    notebook is uploaded. It then builds the document retriever, the RAG QA
    chain and the LangGraph tutor chain from that notebook and stores them in
    the Chainlit user session under ``"docs"``, ``"retrieval_manager"``,
    ``"retrieval_chain"`` and ``"tutor_chain"``.

    If any processing step raises, the exception is logged, the user is told,
    and none of those session keys are set.
    """
    settings = {
        "model": "gpt-3.5-turbo",
        "temperature": 0,
        "top_p": 1,
        "frequency_penalty": 0,
        "presence_penalty": 0,
    }
    # NOTE(review): "settings" is not read anywhere in this file; presumably it
    # is consumed elsewhere. Confirm before removing.
    cl.user_session.set("settings", settings)
    welcome_message = "Welcome to the Notebook-Tutor! Please upload a Jupyter notebook (.ipynb and max. 5mb) to start."
    await cl.Message(content=welcome_message).send()

    # AskFileMessage.send() yields no files if the user does not respond in
    # time, so keep asking. `not files` also guards the files[0] below against
    # an empty list.
    files = None
    while not files:
        files = await cl.AskFileMessage(
            content="Please upload a Jupyter notebook (.ipynb, max. 5mb):",
            accept={"application/x-ipynb+json": [".ipynb"]},
            max_size_mb=5,
        ).send()

    file = files[0]  # only a single upload is requested

    # Build everything first and publish to the session only on success, so a
    # failure cannot leave the session half-initialised.
    try:
        doc_manager = DocumentManager(file.path)
        doc_manager.load_document()
        doc_manager.initialize_retriever()
        docs = doc_manager.get_documents()
        retrieval_manager = RetrievalManager(doc_manager.get_retriever())
        # Initialize LangGraph chain with the retrieval chain
        retrieval_chain = retrieval_manager.get_RAG_QA_chain()
        tutor_chain = create_tutor_chain(retrieval_chain)
    except Exception:
        # Top-level event-handler boundary: record the traceback and tell the
        # user, rather than failing without any visible feedback.
        logger.exception("Failed to process uploaded notebook: %s", file.path)
        await cl.Message(
            content="Sorry, the notebook could not be processed. Please start a new chat and upload a different notebook."
        ).send()
        return

    cl.user_session.set("docs", docs)
    cl.user_session.set("retrieval_manager", retrieval_manager)
    cl.user_session.set("retrieval_chain", retrieval_chain)
    cl.user_session.set("tutor_chain", tutor_chain)

    logger.info("Chat started and notebook uploaded successfully.")

# Maximum number of graph steps LangGraph may take for a single user message.
_RECURSION_LIMIT = 10

# State flags set by agents when they produced user-facing output this turn.
_OUTPUT_FLAGS = ("quiz_created", "question_answered", "flashcards_created")


def _last_message_content(state):
    """Return the content of the last message in *state*, or ``None``.

    *state* is a node update or a final graph state (a dict). Anything that is
    not a dict, or that has no non-empty ``"messages"`` entry, yields ``None``.
    """
    if not isinstance(state, dict):
        return None
    messages = state.get("messages")
    if not messages:
        return None
    return messages[-1].content


async def _send_flashcards_file(flashcard_filename):
    """Send the flashcards file named *flashcard_filename* as a chat attachment.

    The name is resolved against the current working directory. If no file
    exists there, a warning is logged and the user is told instead.
    """
    flashcard_path = os.path.abspath(flashcard_filename)
    logger.debug("Flashcard path: %s", flashcard_path)
    if not os.path.isfile(flashcard_path):
        logger.warning("Flashcards reported as created but file not found: %s", flashcard_path)
        await cl.Message(content="Sorry, the flashcards file could not be found.").send()
        return

    file_element = cl.File(name=os.path.basename(flashcard_filename), path=flashcard_path)
    logger.debug("Sending flashcards file: %s", file_element)
    await cl.Message(
        content="Here are your flashcards:",
        elements=[file_element],
    ).send()


@cl.on_message
async def main(message: cl.Message):
    """Run one user message through the tutor graph and relay its output.

    Each message starts a fresh ``TutorState`` that holds only that message.
    The last message of every streamed node update is sent to the chat as it
    arrives. When the run ends:

    - the final message is sent if an agent set one of the output flags and
      that text was not already relayed;
    - the flashcards file is attached if ``flashcards_created`` is set.

    The final state comes from the ``"__end__"`` chunk when the stream emits
    one. Otherwise it is rebuilt from the non-message fields of the node
    updates, so flashcards are still delivered in that case.
    """
    # Retrieve the LangGraph chain from the session
    tutor_chain = cl.user_session.get("tutor_chain")

    if not tutor_chain:
        await cl.Message(content="No document processing setup found. Please upload a Jupyter notebook first.").send()
        return

    # NOTE(review): no earlier conversation turns are carried over; every
    # message is processed in isolation.
    state = TutorState(
        messages=[HumanMessage(content=message.content)],
        next="supervisor",
        quiz=[],
        quiz_created=False,
        question_answered=False,
        flashcards_created=False,
        flashcard_filename="",
    )
    logger.debug("Initial state: %s", state)

    last_sent = None      # text of the last message relayed to the user
    reported = {}         # non-message fields seen in node updates
    final_state = None

    # NOTE(review): stream() is synchronous and blocks the event loop while the
    # graph runs. Consider an async stream if the compiled graph supports one.
    for update in tutor_chain.stream(state, {"recursion_limit": _RECURSION_LIMIT}):
        logger.debug("State after processing: %s", update)

        if "__end__" in update:
            final_state = update["__end__"] or {}
            break

        node_output = next(iter(update.values()), None)
        content = _last_message_content(node_output)
        if content is None:
            # Updates without messages (e.g. pure routing) are skipped.
            logger.debug("No messages found in agent state: %s", node_output)
        else:
            await cl.Message(content=content).send()
            last_sent = content

        if isinstance(node_output, dict):
            reported.update({key: value for key, value in node_output.items() if key != "messages"})

    if final_state is None:
        # No "__end__" chunk; fall back to the flags the nodes reported.
        final_state = reported

    # Send the final message once, and only if it was not already streamed.
    if any(final_state.get(flag) for flag in _OUTPUT_FLAGS):
        final_content = _last_message_content(final_state)
        if final_content is not None and final_content != last_sent:
            await cl.Message(content=final_content).send()

    if final_state.get("flashcards_created"):
        flashcard_filename = final_state.get("flashcard_filename", "")
        logger.debug("Flashcard filename: %s", flashcard_filename)
        await _send_flashcards_file(flashcard_filename)

    logger.debug("Reached END state.")