File size: 3,216 Bytes
0f64bae
7b47aa3
48d9af7
 
 
 
 
 
 
 
 
 
 
 
 
c21a510
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48d9af7
 
 
 
 
 
 
 
 
19e42bb
0f64bae
 
 
7b47aa3
48d9af7
 
 
 
 
 
 
0f64bae
 
 
 
48d9af7
 
 
 
 
 
 
 
 
0f64bae
 
c21a510
 
 
 
 
 
 
 
 
 
 
 
0f64bae
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
from typing import Optional, Type, Annotated
from pydantic import BaseModel, Field
from langchain.tools import BaseTool
from langchain.callbacks.manager import (
    AsyncCallbackManagerForToolRun,
    CallbackManagerForToolRun,
)
import csv
import uuid
import os

class FlashcardInput(BaseModel):
    # Input schema for the create_flashcards tool. Documented with comments
    # rather than a docstring on purpose: a model docstring would be emitted
    # as the schema "description" and change what the agent sees.

    # Required field (explicit Ellipsis default): the cards to write out.
    flashcards: list = Field(
        ...,
        description="A list of flashcards. Each flashcard should be a dictionary with 'question' and 'answer' keys.",
    )

class FlashcardTool(BaseTool):
    """
    FlashcardTool class.

    A LangChain tool that writes flashcards to a .csv file suitable for
    import into Anki (columns ``Front`` and ``Back``, with a header row).

    Attributes:
        name (str): The name of the tool.
        description (str): The description of the tool.
        args_schema (Type[BaseModel]): The schema for the input arguments of the tool.

    Methods:
        _run(flashcards: list, run_manager: Optional[CallbackManagerForToolRun] = None) -> str:
            Write the flashcards to a new CSV file under ``flashcards/``.

        _arun(flashcards: list, run_manager: Optional[AsyncCallbackManagerForToolRun] = None) -> str:
            Not supported; always raises NotImplementedError.
    """
    # Explicit type annotations are required: BaseTool is a pydantic model and
    # pydantic v2 (used by newer LangChain releases) rejects overriding an
    # annotated base-class field with a bare, non-annotated class attribute.
    # The annotations are harmless under pydantic v1.
    name: str = "create_flashcards"
    description: str = "Create flashcards in a .csv format suitable for import into Anki"
    args_schema: Type[BaseModel] = FlashcardInput

    def _run(
        self, flashcards: list, run_manager: Optional[CallbackManagerForToolRun] = None
    ) -> str:
        """Write ``flashcards`` to a new, uniquely named CSV file.

        Args:
            flashcards: Cards to write. Each must be a mapping with
                ``'question'`` and ``'answer'`` keys.
            run_manager: Callback manager supplied by LangChain; unused here.

        Returns:
            A short confirmation message for the calling agent.

        Raises:
            KeyError: If a card lacks a ``'question'`` or ``'answer'`` key.
                Raised before any directory or file is created.
            OSError: If the output directory or file cannot be created.
        """
        # Build every row first so malformed input fails before anything is
        # written, instead of leaving a half-written CSV behind.
        rows = [
            {'Front': card['question'], 'Back': card['answer']}
            for card in flashcards
        ]

        # A random UUID keeps repeated invocations from overwriting earlier files.
        filename = f"flashcards_{uuid.uuid4()}.csv"
        save_path = os.path.join('flashcards', filename)
        os.makedirs(os.path.dirname(save_path), exist_ok=True)

        # newline='' is what the csv module requires. The encoding is explicit
        # because the platform default (e.g. cp1252 on Windows) cannot encode
        # arbitrary card text, and Anki's text importer expects UTF-8.
        with open(save_path, 'w', newline='', encoding='utf-8') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=['Front', 'Back'])
            writer.writeheader()
            writer.writerows(rows)

        print("\033[93m" + f"Flashcards successfully created and saved to {save_path}" + "\033[0m")

        return "csv file created successfully."

    async def _arun(
        self, flashcards: list, run_manager: Optional[AsyncCallbackManagerForToolRun] = None
    ) -> str:
        """Async entry point; not supported by this tool.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("create_flashcards does not support async")

# Ready-to-use instance of the tool.
create_flashcards_tool = FlashcardTool()

class RetrievalChainWrapper:
    """
    Thin adapter that exposes a retrieval (RAG) chain as a single callable.

    The wrapped chain is invoked with a ``{"question": ...}`` payload. The
    ``content`` attribute of the ``"response"`` entry in its result is handed
    back to the caller.

    Attributes:
        retrieval_chain: The wrapped chain object; it must provide ``invoke``.

    Methods:
        retrieve_information(query: str) -> str:
            Ask the wrapped chain a question and return the answer content.
    """

    def __init__(self, retrieval_chain):
        # Stored as-is; no validation happens here.
        self.retrieval_chain = retrieval_chain

    def retrieve_information(
        self,
        query: Annotated[str, "query to ask the RAG tool"]
    ):
        """Use this tool to retrieve information about the provided notebook."""
        # NOTE(review): the docstring above is kept unchanged on purpose. It is
        # presumably surfaced to an agent as the tool description if this bound
        # method is wrapped as a tool; verify before editing it.
        payload = {"question": query}
        result = self.retrieval_chain.invoke(payload)
        # Assumes the chain returns a mapping whose "response" value is a
        # message-like object exposing ``.content``. TODO: confirm.
        answer = result["response"]
        return answer.content