JulsdL's picture
Code documentation and dockerizing AI Notebook Tutor
c21a510
raw
history blame
3.22 kB
from typing import Optional, Type, Annotated
from pydantic import BaseModel, Field
from langchain.tools import BaseTool
from langchain.callbacks.manager import (
AsyncCallbackManagerForToolRun,
CallbackManagerForToolRun,
)
import csv
import uuid
import os
class FlashcardInput(BaseModel):
    # Argument schema for the flashcard tool: a single required field holding
    # the cards to write. Documented with comments rather than a docstring so
    # the schema generated from this model stays exactly as it was.
    flashcards: list = Field(
        description="A list of flashcards. Each flashcard should be a dictionary with 'question' and 'answer' keys.",
    )
class FlashcardTool(BaseTool):
    """
    Tool that writes flashcards to a .csv file suitable for import into Anki.

    Each call creates a new, uniquely named file inside the relative
    ``flashcards`` directory. The file starts with a ``Front``/``Back`` header
    row followed by one row per flashcard.

    Attributes:
        name (str): The name of the tool.
        description (str): The description of the tool.
        args_schema (Type[BaseModel]): The schema for the input arguments of the tool.

    Methods:
        _run(flashcards: list, run_manager: Optional[CallbackManagerForToolRun] = None) -> str:
            Use the tool to create flashcards.
        _arun(flashcards: list, run_manager: Optional[AsyncCallbackManagerForToolRun] = None) -> str:
            Use the tool asynchronously (not supported).
    """

    # Explicit annotations: pydantic-v2-based BaseTool versions reject an
    # inherited field that is overridden by a non-annotated attribute.
    name: str = "create_flashcards"
    description: str = "Create flashcards in a .csv format suitable for import into Anki"
    args_schema: Type[BaseModel] = FlashcardInput

    def _run(
        self, flashcards: list, run_manager: Optional[CallbackManagerForToolRun] = None
    ) -> str:
        """Write the given flashcards to a new CSV file.

        Args:
            flashcards: Items indexable by ``'question'`` and ``'answer'``;
                each one becomes a ``Front``/``Back`` row.
            run_manager: Optional callback manager; not used by this tool.

        Returns:
            A fixed confirmation message.

        Raises:
            KeyError: If a flashcard lacks a ``'question'`` or ``'answer'`` key.
            OSError: If the directory or file cannot be created or written.
        """
        # Build every row up front so that a malformed card fails before any
        # file is created, instead of leaving a partially written CSV behind.
        rows = [
            {'Front': card['question'], 'Back': card['answer']}
            for card in flashcards
        ]

        filename = f"flashcards_{uuid.uuid4()}.csv"
        save_path = os.path.join('flashcards', filename)
        os.makedirs(os.path.dirname(save_path), exist_ok=True)

        # newline='' lets the csv module control line endings; UTF-8 is set
        # explicitly so non-ASCII card text does not depend on the platform's
        # default locale encoding.
        with open(save_path, 'w', newline='', encoding='utf-8') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=['Front', 'Back'])
            writer.writeheader()
            writer.writerows(rows)

        # ANSI yellow so the confirmation stands out in console output.
        print("\033[93m" + f"Flashcards successfully created and saved to {save_path}" + "\033[0m")
        return "csv file created successfully."

    async def _arun(
        self, flashcards: list, run_manager: Optional[AsyncCallbackManagerForToolRun] = None
    ) -> str:
        """Use the tool asynchronously.

        Raises:
            NotImplementedError: Always; this tool has no async implementation.
        """
        raise NotImplementedError("create_flashcards does not support async")
# Instantiate the tool once at import time and expose it as a module-level name.
create_flashcards_tool = FlashcardTool()
class RetrievalChainWrapper:
    """
    Thin adapter around a retrieval chain.

    Holds a chain object and exposes a single method that forwards a query
    string to it and returns the text content of the chain's response.

    Attributes:
        retrieval_chain: The wrapped chain; it must provide ``invoke``.

    Methods:
        retrieve_information(query: str) -> str:
            Use this tool to retrieve information about the provided notebook.
    """

    def __init__(self, retrieval_chain):
        # Stored as-is; the chain is only touched when a query is made.
        self.retrieval_chain = retrieval_chain

    def retrieve_information(
        self,
        query: Annotated[str, "query to ask the RAG tool"]
    ):
        """Use this tool to retrieve information about the provided notebook."""
        # NOTE(review): the docstring above and the Annotated metadata are kept
        # verbatim; presumably they serve as the tool/argument descriptions
        # when this method is registered as a tool - confirm before rewording.
        payload = {"question": query}
        result = self.retrieval_chain.invoke(payload)
        # The chain's output is indexed by "response"; that entry exposes the
        # answer text through its ``content`` attribute.
        message = result["response"]
        return message.content