File size: 5,915 Bytes
48d9af7
 
c4eb0c2
 
 
 
deeba11
48d9af7
0f64bae
c4eb0c2
 
 
 
48d9af7
 
 
 
 
c4eb0c2
 
 
0f64bae
c4eb0c2
 
 
 
 
 
19e42bb
c4eb0c2
 
 
 
 
19e42bb
c4eb0c2
 
 
 
 
4e8b6ab
c4eb0c2
 
 
 
 
 
 
 
deeba11
 
4e8b6ab
48d9af7
 
 
c21a510
 
 
0f64bae
 
c21a510
 
 
 
deeba11
c4eb0c2
 
c21a510
 
 
 
 
 
19e42bb
deeba11
48d9af7
deeba11
48d9af7
c4eb0c2
 
 
deeba11
 
48d9af7
 
 
 
 
 
 
 
deeba11
0f64bae
deeba11
 
0f64bae
 
c4eb0c2
7b47aa3
48d9af7
7b47aa3
 
 
0f64bae
4e8b6ab
 
7b47aa3
 
 
0f64bae
7b47aa3
 
 
 
 
0f64bae
48d9af7
 
0f64bae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c21a510
 
 
 
 
 
0f64bae
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
import os
import logging
import chainlit as cl
from dotenv import load_dotenv
from document_processing import DocumentManager
from retrieval import RetrievalManager
from langchain_core.messages import AIMessage, HumanMessage
from graph import create_tutor_chain, TutorState
import shutil

# Read KEY=VALUE pairs from a .env file into the process environment
# (python-dotenv) before anything below looks up configuration.
load_dotenv()

# Logger for this module; the root logger is configured to emit INFO and above.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

@cl.on_chat_start
async def start_chat():
    """
    Initialize a new chat session.

    Stores default model settings in the user session, greets the user and
    keeps asking for a Jupyter notebook upload until one is received. The
    notebook is loaded through DocumentManager, a RetrievalManager and its RAG
    QA chain are built, and the tutor chain created from it is stored under
    "tutor_chain" for the message handler.

    Session values ("docs", "retrieval_manager", "retrieval_chain",
    "tutor_chain") are only written once every processing step succeeded, so a
    failed upload never leaves a half-initialized session behind.
    """
    # NOTE(review): nothing in this file reads "settings"; also "gpt4o" does not
    # match the OpenAI model id "gpt-4o" — confirm where this is consumed.
    settings = {
        "model": "gpt4o",
        "temperature": 0,
        "top_p": 1,
        "frequency_penalty": 0,
        "presence_penalty": 0,
    }
    cl.user_session.set("settings", settings)
    welcome_message = "Welcome to the Notebook-Tutor!"
    await cl.Message(content=welcome_message).send()

    # AskFileMessage(...).send() yields None when no file is provided (e.g. on
    # timeout), so keep prompting until an upload arrives.
    files = None
    while files is None:
        files = await cl.AskFileMessage(
            content="Please upload a Jupyter notebook (.ipynb, max. 5mb) to start:",
            accept={"application/x-ipynb+json": [".ipynb"]},
            max_size_mb=5
        ).send()

    file = files[0]  # Only the first uploaded file is used.
    if not file:
        return

    # NOTE(review): these calls are synchronous and run on the event loop;
    # consider cl.make_async if loading/indexing large notebooks is slow.
    try:
        doc_manager = DocumentManager(file.path)
        doc_manager.load_document()
        doc_manager.initialize_retriever()
        docs = doc_manager.get_documents()
        retrieval_manager = RetrievalManager(doc_manager.get_retriever())

        # Initialize LangGraph chain with the retrieval chain
        retrieval_chain = retrieval_manager.get_RAG_QA_chain()
        tutor_chain = create_tutor_chain(retrieval_chain)
    except Exception:
        # Handler boundary: log the full traceback and tell the user, instead of
        # failing with a partially populated session.
        logger.exception("Failed to process uploaded notebook: %s", file.path)
        await cl.Message(
            content="Sorry, the notebook could not be processed. Please start a new chat and try again."
        ).send()
        return

    cl.user_session.set("docs", docs)
    cl.user_session.set("retrieval_manager", retrieval_manager)
    cl.user_session.set("retrieval_chain", retrieval_chain)
    cl.user_session.set("tutor_chain", tutor_chain)

    logger.info("Chat started and notebook uploaded successfully.")

    ready_to_chat_message = "Notebook uploaded and processed successfully!"
    await cl.Message(content=ready_to_chat_message).send()

    invite_message = "You can now ask questions or request quizzes and flashcards based on the notebook content."
    await cl.Message(content=invite_message).send()



# Agent node names paired with the state flag each one sets once its output is
# ready. Within one streamed update, replies are sent in this order.
_AGENT_COMPLETION_FLAGS = (
    ("QAAgent", "question_answered"),
    ("QuizAgent", "quiz_created"),
    ("FlashcardsAgent", "flashcards_created"),
)


def _latest_flashcard_file(directory='flashcards'):
    """
    Return the most recently modified 'flashcards_*.csv' path under `directory`.

    The directory is searched recursively with os.walk. Returns None when the
    directory is missing or holds no matching file. Files that disappear while
    being inspected (e.g. removed by a concurrent cleanup) are skipped.

    NOTE(review): the directory is shared by all chat sessions, so the newest
    file may belong to another user — confirm whether per-session paths are needed.
    """
    latest_path = None
    latest_time = 0
    for root, _dirs, names in os.walk(directory):
        for name in names:
            if not (name.startswith('flashcards_') and name.endswith('.csv')):
                continue
            path = os.path.join(root, name)
            try:
                mtime = os.path.getmtime(path)
            except OSError:
                continue
            if mtime > latest_time:
                latest_time = mtime
                latest_path = path
    return latest_path


async def _send_flashcard_file():
    """Attach the newest flashcards CSV, if one exists, to a new chat message."""
    flashcard_file = _latest_flashcard_file()
    if not flashcard_file:
        return

    logger.info("Flashcard path: %s", flashcard_file)
    file_element = cl.File(name="Flashcards", path=flashcard_file, display="inline")
    logger.info("Sending flashcards file: %s", file_element)

    await cl.Message(
        content="Download the flashcards in .csv here:",
        elements=[file_element]
    ).send()


@cl.on_message
async def main(message: cl.Message):
    """
    Process a user message through the LangGraph tutor chain.

    Streams the chain's per-node updates. For each of QAAgent, QuizAgent and
    FlashcardsAgent whose update sets its completion flag, the last message of
    that agent's own update is sent to the user. After flashcards are created,
    the newest flashcards CSV found under 'flashcards' is attached for download.

    Parameters:
    - message (cl.Message): The message to be processed.
    """
    tutor_chain = cl.user_session.get("tutor_chain")
    if not tutor_chain:
        await cl.Message(content="No document processing setup found. Please upload a Jupyter notebook first.").send()
        return

    # Fresh graph state containing only the user's message.
    state = TutorState(
        messages=[HumanMessage(content=message.content)],
        next="supervisor",
        quiz=[],
        quiz_created=False,
        question_answered=False,
        flashcards_created=False,
    )
    logger.info("Initial state: %s", state)

    # NOTE(review): .stream() is synchronous and blocks the event loop while the
    # graph runs; if tutor_chain is a compiled LangGraph graph, `astream` would
    # avoid that — confirm before switching.
    for update in tutor_chain.stream(state, {"recursion_limit": 10}):
        logger.info("State after processing: %s", update)

        for agent_name, done_flag in _AGENT_COMPLETION_FLAGS:
            # Use this agent's own update (not whichever node comes first), and
            # tolerate partial updates that omit the flag.
            agent_state = update.get(agent_name)
            if not agent_state or not agent_state.get(done_flag):
                continue

            agent_message = agent_state["messages"][-1].content
            logger.info("Sending %s message: %s", agent_name, agent_message)
            await cl.Message(content=agent_message).send()

            if agent_name == "FlashcardsAgent":
                await _send_flashcard_file()

    logger.info("Reached END state.")


@cl.on_chat_end
async def end_chat():
    """
    Reset the flashcards directory when a chat session ends.

    Removes the 'flashcards' directory and all of its contents if it exists,
    then creates it again (empty). If the directory does not exist, a new empty
    directory with the same name is created, so it always exists afterwards.

    NOTE(review): the directory is shared by every session in this process, so
    ending one chat also deletes flashcard files other active sessions may still
    want to download — confirm whether per-session directories are intended.
    """
    flashcard_directory = 'flashcards'
    if os.path.exists(flashcard_directory):
        shutil.rmtree(flashcard_directory)
    # exist_ok: do not fail if the directory was recreated in the meantime.
    os.makedirs(flashcard_directory, exist_ok=True)